mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 04:05:01 +08:00
[Core] Support multi-node inference(eager and cuda graph) (#3686)
This commit is contained in:
parent
a4075cba4d
commit
515386ef3c
@ -24,7 +24,7 @@ def all_reduce_test_worker(tensor_parallel_size: int, rank: int,
|
||||
del os.environ["CUDA_VISIBLE_DEVICES"]
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
torch.cuda.set_device(device)
|
||||
init_test_distributed_environment(1, tensor_parallel_size, rank,
|
||||
init_test_distributed_environment(1, tensor_parallel_size, rank, rank,
|
||||
distributed_init_port)
|
||||
num_elements = 8
|
||||
all_tensors = [
|
||||
@ -46,7 +46,7 @@ def all_gather_test_worker(tensor_parallel_size: int, rank: int,
|
||||
del os.environ["CUDA_VISIBLE_DEVICES"]
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
torch.cuda.set_device(device)
|
||||
init_test_distributed_environment(1, tensor_parallel_size, rank,
|
||||
init_test_distributed_environment(1, tensor_parallel_size, rank, rank,
|
||||
distributed_init_port)
|
||||
num_dimensions = 3
|
||||
tensor_size = list(range(2, num_dimensions + 2))
|
||||
@ -74,7 +74,7 @@ def broadcast_tensor_dict_test_worker(tensor_parallel_size: int, rank: int,
|
||||
del os.environ["CUDA_VISIBLE_DEVICES"]
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
torch.cuda.set_device(device)
|
||||
init_test_distributed_environment(1, tensor_parallel_size, rank,
|
||||
init_test_distributed_environment(1, tensor_parallel_size, rank, rank,
|
||||
distributed_init_port)
|
||||
test_dict = {
|
||||
"a": torch.arange(8, dtype=torch.float32, device="cuda"),
|
||||
|
||||
@ -23,7 +23,7 @@ def graph_allreduce(world_size, rank, distributed_init_port):
|
||||
del os.environ["CUDA_VISIBLE_DEVICES"]
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
torch.cuda.set_device(device)
|
||||
init_test_distributed_environment(1, world_size, rank,
|
||||
init_test_distributed_environment(1, world_size, rank, rank,
|
||||
distributed_init_port)
|
||||
|
||||
custom_ar.init_custom_ar()
|
||||
@ -58,7 +58,7 @@ def eager_allreduce(world_size, rank, distributed_init_port):
|
||||
del os.environ["CUDA_VISIBLE_DEVICES"]
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
torch.cuda.set_device(device)
|
||||
init_test_distributed_environment(1, world_size, rank,
|
||||
init_test_distributed_environment(1, world_size, rank, rank,
|
||||
distributed_init_port)
|
||||
|
||||
sz = 1024
|
||||
|
||||
@ -188,8 +188,6 @@ class RayGPUExecutor(ExecutorBase):
|
||||
is_driver_worker=True,
|
||||
)
|
||||
|
||||
# FIXME(woosuk): We are not properly initializing pynccl when
|
||||
# we have multiple nodes.
|
||||
self._run_workers("init_device")
|
||||
self._run_workers(
|
||||
"load_model",
|
||||
|
||||
@ -202,6 +202,7 @@ class NCCLCommunicator:
|
||||
init_method=None,
|
||||
timeout=datetime.timedelta(seconds=10),
|
||||
world_size: int = -1,
|
||||
local_rank: int = -1,
|
||||
rank: int = -1,
|
||||
store=None,
|
||||
group_name: str = "",
|
||||
@ -219,25 +220,22 @@ class NCCLCommunicator:
|
||||
store=store,
|
||||
group_name=group_name,
|
||||
pg_options=pg_options)
|
||||
self.world_size = dist.get_world_size()
|
||||
self.rank = dist.get_rank()
|
||||
torch.cuda.set_device(self.rank)
|
||||
if self.rank == 0:
|
||||
torch.cuda.set_device(local_rank)
|
||||
if rank == 0:
|
||||
self.unique_id = ncclGetUniqueId()
|
||||
else:
|
||||
self.unique_id = NcclUniqueId()
|
||||
tensor = torch.ByteTensor(list(self.unique_id.internal)).cuda(
|
||||
self.rank)
|
||||
tensor = torch.ByteTensor(list(
|
||||
self.unique_id.internal)).cuda(local_rank)
|
||||
dist.broadcast(tensor, src=0)
|
||||
byte_list = tensor.cpu().tolist()
|
||||
self.unique_id = NcclUniqueId()
|
||||
for i, byte in enumerate(byte_list):
|
||||
self.unique_id.internal[i] = byte
|
||||
self.comm = ctypes.c_void_p()
|
||||
result = _c_ncclCommInitRank(ctypes.byref(self.comm), self.world_size,
|
||||
self.unique_id, self.rank)
|
||||
result = _c_ncclCommInitRank(ctypes.byref(self.comm), world_size,
|
||||
self.unique_id, rank)
|
||||
assert result == 0
|
||||
self.stream = torch.cuda.Stream(device=f"cuda:{self.rank}")
|
||||
self.stream = torch.cuda.Stream(device=f"cuda:{local_rank}")
|
||||
|
||||
def all_reduce(self,
|
||||
tensor: torch.Tensor,
|
||||
|
||||
@ -36,11 +36,13 @@ def set_pynccl_stream(stream: torch.cuda.Stream):
|
||||
pass
|
||||
|
||||
|
||||
def init_process_group(world_size: int, rank: int, init_method: str) -> None:
|
||||
def init_process_group(world_size: int, local_rank: int, rank: int,
|
||||
init_method: str) -> None:
|
||||
assert not is_initialized()
|
||||
global comm
|
||||
comm = NCCLCommunicator(init_method=init_method,
|
||||
world_size=world_size,
|
||||
local_rank=local_rank,
|
||||
rank=rank)
|
||||
|
||||
|
||||
|
||||
@ -8,6 +8,7 @@ from vllm.worker.worker import init_distributed_environment
|
||||
def init_test_distributed_environment(
|
||||
pipeline_parallel_size: int,
|
||||
tensor_parallel_size: int,
|
||||
local_rank: int,
|
||||
rank: int,
|
||||
distributed_init_port: str,
|
||||
) -> None:
|
||||
@ -16,7 +17,10 @@ def init_test_distributed_environment(
|
||||
worker_use_ray=True)
|
||||
distributed_init_method = f"tcp://localhost:{distributed_init_port}"
|
||||
init_distributed_environment(
|
||||
parallel_config, rank, distributed_init_method=distributed_init_method)
|
||||
parallel_config,
|
||||
local_rank,
|
||||
rank,
|
||||
distributed_init_method=distributed_init_method)
|
||||
|
||||
|
||||
def multi_process_tensor_parallel(
|
||||
|
||||
@ -97,8 +97,8 @@ class Worker:
|
||||
raise RuntimeError(
|
||||
f"Not support device type: {self.device_config.device}")
|
||||
# Initialize the distributed environment.
|
||||
init_distributed_environment(self.parallel_config, self.rank,
|
||||
self.distributed_init_method)
|
||||
init_distributed_environment(self.parallel_config, self.local_rank,
|
||||
self.rank, self.distributed_init_method)
|
||||
# Set random seed.
|
||||
set_random_seed(self.model_config.seed)
|
||||
|
||||
@ -249,6 +249,7 @@ class Worker:
|
||||
|
||||
def init_distributed_environment(
|
||||
parallel_config: ParallelConfig,
|
||||
local_rank: int,
|
||||
rank: int,
|
||||
distributed_init_method: Optional[str] = None,
|
||||
) -> None:
|
||||
@ -282,9 +283,9 @@ def init_distributed_environment(
|
||||
elif parallel_config.world_size > 1:
|
||||
# NOTE(woosuk): We don't initialize pynccl process group when world size
|
||||
# is 1.
|
||||
# TODO(woosuk): Support multi-node connection.
|
||||
pynccl_utils.init_process_group(
|
||||
world_size=parallel_config.world_size,
|
||||
local_rank=local_rank,
|
||||
rank=rank,
|
||||
init_method=distributed_init_method,
|
||||
)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user