PP comm optimization: replace send with partial send + allgather (#6695)

Co-authored-by: Aurick Qiao <aurick.qiao@snowflake.com>
This commit is contained in:
Aurick Qiao 2024-07-31 20:15:42 -07:00 committed by GitHub
parent 630dd9e0ae
commit 0437492ea9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 41 additions and 5 deletions

View File

@ -569,7 +569,8 @@ class GroupCoordinator:
def send_tensor_dict(
self,
tensor_dict: Dict[str, Union[torch.Tensor, Any]],
dst: Optional[int] = None
dst: Optional[int] = None,
all_gather_group: Optional["GroupCoordinator"] = None,
) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
"""Send the input tensor dictionary.
NOTE: `dst` is the local rank of the source rank.
@ -578,6 +579,11 @@ class GroupCoordinator:
if not torch.distributed.is_initialized() or self.world_size == 1:
return tensor_dict
all_gather_size = (1 if all_gather_group is None else
all_gather_group.world_size)
all_gather_rank = (0 if all_gather_group is None else
all_gather_group.rank_in_group)
group = self.device_group
metadata_group = self.cpu_group
@ -598,6 +604,12 @@ class GroupCoordinator:
if tensor.numel() == 0:
# Skip sending empty tensors.
continue
# send-allgather: send only a slice, then do allgather.
if (all_gather_group is not None
and tensor.numel() % all_gather_size == 0):
tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank]
if tensor.is_cpu:
# use metadata_group for CPU tensors
torch.distributed.send(tensor,
@ -612,7 +624,8 @@ class GroupCoordinator:
def recv_tensor_dict(
self,
src: Optional[int] = None
src: Optional[int] = None,
all_gather_group: Optional["GroupCoordinator"] = None,
) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
"""Recv the input tensor dictionary.
NOTE: `src` is the local rank of the source rank.
@ -621,6 +634,11 @@ class GroupCoordinator:
if not torch.distributed.is_initialized() or self.world_size == 1:
return None
all_gather_size = (1 if all_gather_group is None else
all_gather_group.world_size)
all_gather_rank = (0 if all_gather_group is None else
all_gather_group.rank_in_group)
group = self.device_group
metadata_group = self.cpu_group
@ -639,6 +657,16 @@ class GroupCoordinator:
# Skip broadcasting empty tensors.
tensor_dict[key] = tensor
continue
# send-allgather: send only a slice, then do allgather.
use_all_gather = (all_gather_group is not None
and tensor.numel() % all_gather_size == 0)
if use_all_gather:
orig_shape = tensor.shape
tensor = tensor.reshape(all_gather_size,
-1)[all_gather_rank]
if tensor.is_cpu:
# use metadata_group for CPU tensors
torch.distributed.recv(tensor,
@ -649,6 +677,12 @@ class GroupCoordinator:
torch.distributed.recv(tensor,
src=self.ranks[src],
group=group)
if use_all_gather:
# do the allgather
tensor = all_gather_group.all_gather( # type: ignore
tensor, dim=0)
tensor = tensor.reshape(orig_shape)
tensor_dict[key] = tensor
else:
tensor_dict[key] = value

View File

@ -6,7 +6,7 @@ from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
import torch
from vllm.distributed import broadcast_tensor_dict, get_pp_group
from vllm.distributed import broadcast_tensor_dict, get_pp_group, get_tp_group
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.platforms import current_platform
@ -267,7 +267,8 @@ class LocalOrDistributedWorkerBase(WorkerBase):
intermediate_tensors = None
if not get_pp_group().is_first_rank:
intermediate_tensors = IntermediateTensors(
get_pp_group().recv_tensor_dict())
get_pp_group().recv_tensor_dict(
all_gather_group=get_tp_group()))
output = self.model_runner.execute_model(
model_input, self.kv_cache[worker_input.virtual_engine]
@ -276,7 +277,8 @@ class LocalOrDistributedWorkerBase(WorkerBase):
if not get_pp_group().is_last_rank:
# output is IntermediateTensors
get_pp_group().send_tensor_dict(output.tensors)
get_pp_group().send_tensor_dict(output.tensors,
all_gather_group=get_tp_group())
return [None]
# output is List[SamplerOutput]