diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py
index 823ed56b23806..5b8370a56e488 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py
@@ -19,7 +19,7 @@ import msgspec
 import numpy as np
 import torch
 import zmq
-
+from vllm import envs
 from vllm.attention.backends.registry import AttentionBackendEnum
 from vllm.attention.selector import get_attn_backend
 from vllm.config import VllmConfig
@@ -201,9 +201,9 @@ class TransferError(MoRIIOError):
 
 
 def get_moriio_mode() -> MoRIIOMode:
-    read_mode = os.environ.get("MORIIO_CONNECTOR_READ_MODE", "false").lower()
+    read_mode = envs.VLLM_MORIIO_CONNECTOR_READ_MODE
     logger.debug("MoRIIO Connector read_mode: %s", read_mode)
-    if read_mode in ("true", "1", "yes", "on"):
+    if read_mode:
         return MoRIIOMode.READ
     else:
         return MoRIIOMode.WRITE
@@ -575,9 +575,9 @@ class MoRIIOWrapper:
 
     def set_backend_type(self, backend_type):
         assert self.moriio_engine is not None, "MoRIIO engine must be set first"
-        qp_per_transfer = int(os.getenv("VLLM_MORI_QP_PER_TRANSFER", "1"))
-        post_batch_size = int(os.getenv("VLLM_MORI_POST_BATCH_SIZE", "-1"))
-        num_worker_threads = int(os.getenv("VLLM_MORI_NUM_WORKERS", "1"))
+        qp_per_transfer = envs.VLLM_MORIIO_QP_PER_TRANSFER
+        post_batch_size = envs.VLLM_MORIIO_POST_BATCH_SIZE
+        num_worker_threads = envs.VLLM_MORIIO_NUM_WORKERS
         poll_mode = PollCqMode.POLLING
         rdma_cfg = RdmaBackendConfig(
             qp_per_transfer,
diff --git a/vllm/envs.py b/vllm/envs.py
index 56558548d3981..f3b212dbe59cb 100755
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -193,6 +193,10 @@ if TYPE_CHECKING:
     VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16: bool = True
     VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB: int | None = None
     VLLM_NIXL_ABORT_REQUEST_TIMEOUT: int = 480
+    VLLM_MORIIO_CONNECTOR_READ_MODE: bool = False
+    VLLM_MORIIO_QP_PER_TRANSFER: int = 1
+    VLLM_MORIIO_POST_BATCH_SIZE: int = -1
+    VLLM_MORIIO_NUM_WORKERS: int = 1
     VLLM_USE_CUDNN_PREFILL: bool = False
     VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL: bool = False
     VLLM_ENABLE_CUDAGRAPH_GC: bool = False
@@ -1343,6 +1347,25 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "VLLM_NIXL_ABORT_REQUEST_TIMEOUT": lambda: int(
         os.getenv("VLLM_NIXL_ABORT_REQUEST_TIMEOUT", "480")
     ),
+    # Controls the read mode for the Mori-IO connector
+    "VLLM_MORIIO_CONNECTOR_READ_MODE": lambda: (
+        os.getenv("VLLM_MORIIO_CONNECTOR_READ_MODE", "False").lower()
+        in ("true", "1")
+    ),
+    # Controls the QP (Queue Pair) per transfer configuration for the Mori-IO connector
+    "VLLM_MORIIO_QP_PER_TRANSFER": lambda: int(
+        os.getenv("VLLM_MORIIO_QP_PER_TRANSFER", "1")
+    ),
+
+    # Controls the post-processing batch size for the Mori-IO connector
+    "VLLM_MORIIO_POST_BATCH_SIZE": lambda: int(
+        os.getenv("VLLM_MORIIO_POST_BATCH_SIZE", "-1")
+    ),
+
+    # Controls the number of workers for Mori operations for the Mori-IO connector
+    "VLLM_MORIIO_NUM_WORKERS": lambda: int(
+        os.getenv("VLLM_MORIIO_NUM_WORKERS", "1")
+    ),
     # Controls whether or not to use cudnn prefill
     "VLLM_USE_CUDNN_PREFILL": lambda: bool(
         int(os.getenv("VLLM_USE_CUDNN_PREFILL", "0"))