[Core][Test] move local_rank to the last arg with default value (#3711)

[Core][Test] move local_rank to the last arg with default value to keep the API compatible (#3711)
This commit is contained in:
youkaichao 2024-03-28 21:19:45 -07:00 committed by GitHub
parent 395aa823ea
commit 756b30a5f3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 24 additions and 14 deletions

View File

@ -24,7 +24,7 @@ def all_reduce_test_worker(tensor_parallel_size: int, rank: int,
del os.environ["CUDA_VISIBLE_DEVICES"] del os.environ["CUDA_VISIBLE_DEVICES"]
device = torch.device(f"cuda:{rank}") device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device) torch.cuda.set_device(device)
init_test_distributed_environment(1, tensor_parallel_size, rank, rank, init_test_distributed_environment(1, tensor_parallel_size, rank,
distributed_init_port) distributed_init_port)
num_elements = 8 num_elements = 8
all_tensors = [ all_tensors = [
@ -46,7 +46,7 @@ def all_gather_test_worker(tensor_parallel_size: int, rank: int,
del os.environ["CUDA_VISIBLE_DEVICES"] del os.environ["CUDA_VISIBLE_DEVICES"]
device = torch.device(f"cuda:{rank}") device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device) torch.cuda.set_device(device)
init_test_distributed_environment(1, tensor_parallel_size, rank, rank, init_test_distributed_environment(1, tensor_parallel_size, rank,
distributed_init_port) distributed_init_port)
num_dimensions = 3 num_dimensions = 3
tensor_size = list(range(2, num_dimensions + 2)) tensor_size = list(range(2, num_dimensions + 2))
@ -74,7 +74,7 @@ def broadcast_tensor_dict_test_worker(tensor_parallel_size: int, rank: int,
del os.environ["CUDA_VISIBLE_DEVICES"] del os.environ["CUDA_VISIBLE_DEVICES"]
device = torch.device(f"cuda:{rank}") device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device) torch.cuda.set_device(device)
init_test_distributed_environment(1, tensor_parallel_size, rank, rank, init_test_distributed_environment(1, tensor_parallel_size, rank,
distributed_init_port) distributed_init_port)
test_dict = { test_dict = {
"a": torch.arange(8, dtype=torch.float32, device="cuda"), "a": torch.arange(8, dtype=torch.float32, device="cuda"),

View File

@ -23,7 +23,7 @@ def graph_allreduce(world_size, rank, distributed_init_port):
del os.environ["CUDA_VISIBLE_DEVICES"] del os.environ["CUDA_VISIBLE_DEVICES"]
device = torch.device(f"cuda:{rank}") device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device) torch.cuda.set_device(device)
init_test_distributed_environment(1, world_size, rank, rank, init_test_distributed_environment(1, world_size, rank,
distributed_init_port) distributed_init_port)
custom_ar.init_custom_ar() custom_ar.init_custom_ar()
@ -58,7 +58,7 @@ def eager_allreduce(world_size, rank, distributed_init_port):
del os.environ["CUDA_VISIBLE_DEVICES"] del os.environ["CUDA_VISIBLE_DEVICES"]
device = torch.device(f"cuda:{rank}") device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device) torch.cuda.set_device(device)
init_test_distributed_environment(1, world_size, rank, rank, init_test_distributed_environment(1, world_size, rank,
distributed_init_port) distributed_init_port)
sz = 1024 sz = 1024

View File

@ -14,7 +14,9 @@ def distributed_run(fn, world_size):
for i in range(number_of_processes): for i in range(number_of_processes):
env = os.environ.copy() env = os.environ.copy()
env['RANK'] = str(i) env['RANK'] = str(i)
env['LOCAL_RANK'] = str(i)
env['WORLD_SIZE'] = str(number_of_processes) env['WORLD_SIZE'] = str(number_of_processes)
env['LOCAL_WORLD_SIZE'] = str(number_of_processes)
env['MASTER_ADDR'] = 'localhost' env['MASTER_ADDR'] = 'localhost'
env['MASTER_PORT'] = '12345' env['MASTER_PORT'] = '12345'
p = multiprocessing.Process(target=fn, args=(env, )) p = multiprocessing.Process(target=fn, args=(env, ))

View File

@ -202,11 +202,11 @@ class NCCLCommunicator:
init_method=None, init_method=None,
timeout=datetime.timedelta(seconds=10), timeout=datetime.timedelta(seconds=10),
world_size: int = -1, world_size: int = -1,
local_rank: int = -1,
rank: int = -1, rank: int = -1,
store=None, store=None,
group_name: str = "", group_name: str = "",
pg_options=None, pg_options=None,
local_rank: int = -1,
): ):
if not dist.is_initialized(): if not dist.is_initialized():
backend = backend or "nccl" backend = backend or "nccl"
@ -220,6 +220,11 @@ class NCCLCommunicator:
store=store, store=store,
group_name=group_name, group_name=group_name,
pg_options=pg_options) pg_options=pg_options)
self.rank = dist.get_rank()
self.world_size = dist.get_world_size()
if local_rank == -1:
local_rank = self.rank
self.local_rank = local_rank
torch.cuda.set_device(local_rank) torch.cuda.set_device(local_rank)
if rank == 0: if rank == 0:
self.unique_id = ncclGetUniqueId() self.unique_id = ncclGetUniqueId()

View File

@ -35,8 +35,10 @@ def set_pynccl_stream(stream: torch.cuda.Stream):
pass pass
def init_process_group(world_size: int, local_rank: int, rank: int, def init_process_group(world_size: int,
init_method: str) -> None: rank: int,
init_method: str,
local_rank: int = -1) -> None:
assert not is_initialized() assert not is_initialized()
global comm global comm
logger.info(f"vLLM is using nccl=={ncclGetVersion()}") logger.info(f"vLLM is using nccl=={ncclGetVersion()}")

View File

@ -8,9 +8,9 @@ from vllm.worker.worker import init_distributed_environment
def init_test_distributed_environment( def init_test_distributed_environment(
pipeline_parallel_size: int, pipeline_parallel_size: int,
tensor_parallel_size: int, tensor_parallel_size: int,
local_rank: int,
rank: int, rank: int,
distributed_init_port: str, distributed_init_port: str,
local_rank: int = -1,
) -> None: ) -> None:
parallel_config = ParallelConfig(pipeline_parallel_size, parallel_config = ParallelConfig(pipeline_parallel_size,
tensor_parallel_size, tensor_parallel_size,
@ -18,9 +18,9 @@ def init_test_distributed_environment(
distributed_init_method = f"tcp://localhost:{distributed_init_port}" distributed_init_method = f"tcp://localhost:{distributed_init_port}"
init_distributed_environment( init_distributed_environment(
parallel_config, parallel_config,
local_rank,
rank, rank,
distributed_init_method=distributed_init_method) distributed_init_method=distributed_init_method,
local_rank=local_rank)
def multi_process_tensor_parallel( def multi_process_tensor_parallel(

View File

@ -97,8 +97,9 @@ class Worker:
raise RuntimeError( raise RuntimeError(
f"Not support device type: {self.device_config.device}") f"Not support device type: {self.device_config.device}")
# Initialize the distributed environment. # Initialize the distributed environment.
init_distributed_environment(self.parallel_config, self.local_rank, init_distributed_environment(self.parallel_config, self.rank,
self.rank, self.distributed_init_method) self.distributed_init_method,
self.local_rank)
# Set random seed. # Set random seed.
set_random_seed(self.model_config.seed) set_random_seed(self.model_config.seed)
@ -249,9 +250,9 @@ class Worker:
def init_distributed_environment( def init_distributed_environment(
parallel_config: ParallelConfig, parallel_config: ParallelConfig,
local_rank: int,
rank: int, rank: int,
distributed_init_method: Optional[str] = None, distributed_init_method: Optional[str] = None,
local_rank: int = -1,
) -> None: ) -> None:
"""Initialize the distributed environment.""" """Initialize the distributed environment."""
if torch.distributed.is_initialized(): if torch.distributed.is_initialized():