[Misc]allow disable pynccl (#25421)

Signed-off-by: Lu Fang <fanglu@fb.com>
Co-authored-by: Lucia (Lu) Fang <fanglu@meta.com>

parent 2a69ab4899
commit f48b6a03ba
vllm/distributed/device_communicators/cuda_communicator.py

@@ -147,6 +147,10 @@ class CudaCommunicator(DeviceCommunicatorBase):
             assert out is not None
             return out
         pynccl_comm = self.pynccl_comm
+        if pynccl_comm is None or pynccl_comm.disabled:
+            out = input_.clone()
+            torch.distributed.all_reduce(out, group=self.device_group)
+            return out
         assert pynccl_comm is not None
         out = pynccl_comm.all_reduce(input_)
         if out is None:
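For readers skimming the diff, here is a standalone sketch of the fallback path this hunk introduces (the helper name and the single-process gloo setup are illustrative, not vllm code): when pynccl is missing or disabled, the all-reduce is delegated to torch.distributed, operating on a clone so the caller's tensor is left untouched.

# Standalone sketch of the fallback added above -- not vllm code; the
# single-process gloo group exists only to make the example runnable.
import os

import torch
import torch.distributed as dist


def all_reduce_fallback(input_: torch.Tensor, group) -> torch.Tensor:
    # Clone first so the caller's tensor is left intact, then let
    # torch.distributed perform the in-place all-reduce on the copy.
    out = input_.clone()
    dist.all_reduce(out, group=group)
    return out


if __name__ == "__main__":
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)
    print(all_reduce_fallback(torch.ones(4), dist.group.WORLD))
    # world_size == 1, so the sum equals the input: tensor([1., 1., 1., 1.])
    dist.destroy_process_group()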
vllm/distributed/device_communicators/pynccl.py

@@ -8,6 +8,7 @@ import torch
 import torch.distributed as dist
 from torch.distributed import ProcessGroup, ReduceOp
 
+import vllm.envs as envs
 from vllm.distributed.device_communicators.pynccl_wrapper import (
     NCCLLibrary, buffer_type, cudaStream_t, ncclComm_t, ncclDataTypeEnum,
     ncclRedOpTypeEnum, ncclUniqueId)
@@ -83,7 +84,7 @@ class PyNcclCommunicator:
         self.group = group
 
         # if world_size == 1, no need to create communicator
-        if self.world_size == 1:
+        if self.world_size == 1 or envs.VLLM_DISABLE_PYNCCL:
             self.available = False
             self.disabled = True
             return
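A minimal sketch of what the changed line does, using a simplified stand-in rather than the real PyNcclCommunicator: with VLLM_DISABLE_PYNCCL set, __init__ marks the communicator disabled and returns before any NCCL library is loaded, so callers take the torch.distributed fallback shown in the first hunk.

# Minimal sketch of the gate above -- FakePyNcclCommunicator is a stand-in,
# not the real class; only the world_size / env check is reproduced.
import os


class FakePyNcclCommunicator:

    def __init__(self, world_size: int):
        disable = os.getenv("VLLM_DISABLE_PYNCCL",
                            "False").lower() in ("true", "1")
        # Same short-circuit as the patched __init__: mark the communicator
        # unusable and return before NCCL initialization would happen.
        if world_size == 1 or disable:
            self.available = False
            self.disabled = True
            return
        self.available = True
        self.disabled = False


os.environ["VLLM_DISABLE_PYNCCL"] = "1"
print(FakePyNcclCommunicator(world_size=4).disabled)  # True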
vllm/envs.py

@@ -98,6 +98,7 @@ if TYPE_CHECKING:
     VLLM_SKIP_P2P_CHECK: bool = False
     VLLM_DISABLED_KERNELS: list[str] = []
     VLLM_DISABLE_NCCL_FOR_DP_SYNCHRONIZATION: bool = False
+    VLLM_DISABLE_PYNCCL: bool = False
     VLLM_USE_V1: bool = True
     VLLM_ROCM_USE_AITER: bool = False
     VLLM_ROCM_USE_AITER_PAGED_ATTN: bool = False
@@ -897,6 +898,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
     (os.getenv("VLLM_DISABLE_NCCL_FOR_DP_SYNCHRONIZATION", "False").lower() in
      ("true", "1")),
 
+    # Disable pynccl (using torch.distributed instead)
+    "VLLM_DISABLE_PYNCCL":
+    lambda:
+    (os.getenv("VLLM_DISABLE_PYNCCL", "False").lower() in ("true", "1")),
+
     # If set, use the V1 code path.
     "VLLM_USE_V1":
     lambda: bool(int(os.getenv("VLLM_USE_V1", "1"))),
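The value parsing follows the same convention as the neighboring entries; a small sketch of which strings enable the flag (the env_flag helper is illustrative, not part of vllm):

# Sketch of the truthiness convention used by the new entry. Only "true"
# and "1" (case-insensitive) enable it; anything else, including leaving
# the variable unset, keeps pynccl on.
import os


def env_flag(name: str) -> bool:
    return os.getenv(name, "False").lower() in ("true", "1")


for value in ("1", "true", "TRUE", "0", "yes"):
    os.environ["VLLM_DISABLE_PYNCCL"] = value
    print(f"VLLM_DISABLE_PYNCCL={value!r} -> {env_flag('VLLM_DISABLE_PYNCCL')}")
# "1"/"true"/"TRUE" -> True; "0" -> False; note "yes" also parses as False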