PP comm optimization: replace send with partial send + allgather (#6695)

Co-authored-by: Aurick Qiao <aurick.qiao@snowflake.com>

commit 0437492ea9 (parent 630dd9e0ae)
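The optimization: when each pipeline-parallel (PP) stage also runs tensor parallelism (TP), the intermediate hidden states are replicated across the TP ranks of a stage, so having every TP rank send the full tensor over the PP link transmits redundant data. After this change each sender transmits only a disjoint 1/TP slice, and the receivers rebuild the full tensor with an allgather inside their TP group. Below is a minimal single-process sketch of the reshape/slice arithmetic the diff relies on; the helper names are illustrative, not vLLM APIs, and plain concatenation stands in for the distributed allgather.

import torch

def send_side_slice(t: torch.Tensor, tp_size: int, tp_rank: int) -> torch.Tensor:
    # Each TP rank sends only its contiguous 1/tp_size slice over the PP link.
    assert t.numel() % tp_size == 0
    return t.reshape(tp_size, -1)[tp_rank]

def recv_side_allgather(slices: list) -> torch.Tensor:
    # Stand-in for all_gather_group.all_gather(tensor, dim=0): concatenating
    # the per-rank slices along dim 0 rebuilds the flattened tensor.
    return torch.cat(slices, dim=0)

hidden = torch.randn(8, 16)                        # a toy intermediate tensor
tp_size = 4
slices = [send_side_slice(hidden, tp_size, r) for r in range(tp_size)]
rebuilt = recv_side_allgather(slices).reshape(hidden.shape)
assert torch.equal(rebuilt, hidden)                # lossless round trip
print(f"each rank sends {slices[0].numel()} of {hidden.numel()} elements")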
vllm/distributed/parallel_state.py

@@ -569,7 +569,8 @@ class GroupCoordinator:
     def send_tensor_dict(
         self,
         tensor_dict: Dict[str, Union[torch.Tensor, Any]],
-        dst: Optional[int] = None
+        dst: Optional[int] = None,
+        all_gather_group: Optional["GroupCoordinator"] = None,
     ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
         """Send the input tensor dictionary.
         NOTE: `dst` is the local rank of the source rank.
@@ -578,6 +579,11 @@ class GroupCoordinator:
         if not torch.distributed.is_initialized() or self.world_size == 1:
             return tensor_dict
 
+        all_gather_size = (1 if all_gather_group is None else
+                           all_gather_group.world_size)
+        all_gather_rank = (0 if all_gather_group is None else
+                           all_gather_group.rank_in_group)
+
         group = self.device_group
         metadata_group = self.cpu_group
 
@@ -598,6 +604,12 @@ class GroupCoordinator:
             if tensor.numel() == 0:
                 # Skip sending empty tensors.
                 continue
+
+            # send-allgather: send only a slice, then do allgather.
+            if (all_gather_group is not None
+                    and tensor.numel() % all_gather_size == 0):
+                tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank]
+
             if tensor.is_cpu:
                 # use metadata_group for CPU tensors
                 torch.distributed.send(tensor,
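Since `reshape(all_gather_size, -1)` flattens in row-major order, TP rank r ends up sending the r-th contiguous chunk of the tensor's storage. For example:

import torch

x = torch.arange(12)
# with all_gather_size = 4, rank 2 sends the third contiguous chunk
print(x.reshape(4, -1)[2])   # tensor([6, 7, 8])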
@@ -612,7 +624,8 @@ class GroupCoordinator:
 
     def recv_tensor_dict(
         self,
-        src: Optional[int] = None
+        src: Optional[int] = None,
+        all_gather_group: Optional["GroupCoordinator"] = None,
     ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
         """Recv the input tensor dictionary.
         NOTE: `src` is the local rank of the source rank.
@@ -621,6 +634,11 @@ class GroupCoordinator:
         if not torch.distributed.is_initialized() or self.world_size == 1:
             return None
 
+        all_gather_size = (1 if all_gather_group is None else
+                           all_gather_group.world_size)
+        all_gather_rank = (0 if all_gather_group is None else
+                           all_gather_group.rank_in_group)
+
         group = self.device_group
         metadata_group = self.cpu_group
 
@@ -639,6 +657,16 @@ class GroupCoordinator:
                     # Skip broadcasting empty tensors.
                     tensor_dict[key] = tensor
                     continue
+
+                # send-allgather: send only a slice, then do allgather.
+                use_all_gather = (all_gather_group is not None
+                                  and tensor.numel() % all_gather_size == 0)
+
+                if use_all_gather:
+                    orig_shape = tensor.shape
+                    tensor = tensor.reshape(all_gather_size,
+                                            -1)[all_gather_rank]
+
                 if tensor.is_cpu:
                     # use metadata_group for CPU tensors
                     torch.distributed.recv(tensor,
@@ -649,6 +677,12 @@ class GroupCoordinator:
                     torch.distributed.recv(tensor,
                                            src=self.ranks[src],
                                            group=group)
+                if use_all_gather:
+                    # do the allgather
+                    tensor = all_gather_group.all_gather(  # type: ignore
+                        tensor, dim=0)
+                    tensor = tensor.reshape(orig_shape)
+
                 tensor_dict[key] = tensor
             else:
                 tensor_dict[key] = value
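Note how the `tensor.numel() % all_gather_size == 0` guard keeps both ends in lockstep: the receiver allocates the full-size tensor from the metadata exchanged up front (hence the saved `orig_shape`), so sender and receiver evaluate the same divisibility test and either both take the slice path or both fall back to the plain full send. A tiny illustration of the fallback condition:

import torch

tp_size = 4
t = torch.arange(10)                      # 10 elements, not divisible by 4
use_all_gather = t.numel() % tp_size == 0
print(use_all_gather)                     # False -> the full tensor is sent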
vllm/worker/worker_base.py

@@ -6,7 +6,7 @@ from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
 
 import torch
 
-from vllm.distributed import broadcast_tensor_dict, get_pp_group
+from vllm.distributed import broadcast_tensor_dict, get_pp_group, get_tp_group
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.platforms import current_platform
@@ -267,7 +267,8 @@ class LocalOrDistributedWorkerBase(WorkerBase):
         intermediate_tensors = None
         if not get_pp_group().is_first_rank:
             intermediate_tensors = IntermediateTensors(
-                get_pp_group().recv_tensor_dict())
+                get_pp_group().recv_tensor_dict(
+                    all_gather_group=get_tp_group()))
 
         output = self.model_runner.execute_model(
             model_input, self.kv_cache[worker_input.virtual_engine]
@@ -276,7 +277,8 @@ class LocalOrDistributedWorkerBase(WorkerBase):
 
         if not get_pp_group().is_last_rank:
             # output is IntermediateTensors
-            get_pp_group().send_tensor_dict(output.tensors)
+            get_pp_group().send_tensor_dict(output.tensors,
+                                            all_gather_group=get_tp_group())
             return [None]
 
         # output is List[SamplerOutput]
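On the worker side the change is just threading the TP group into the existing PP send/recv calls. The payoff is roughly a 1/TP reduction in PP point-to-point traffic, at the cost of an extra allgather inside the TP group, which typically runs over faster intra-node links than the PP hop. Back-of-the-envelope arithmetic with illustrative numbers (not measurements from this PR):

# one fp16 hidden-state tensor of shape [tokens, hidden] per microbatch
tokens, hidden, tp_size = 4096, 8192, 4
bytes_full = tokens * hidden * 2              # every TP rank sends everything
bytes_slice = bytes_full // tp_size           # each rank sends a 1/tp_size slice
print(f"per-rank PP send: {bytes_full / 2**20:.0f} MiB "
      f"-> {bytes_slice / 2**20:.0f} MiB")    # 64 MiB -> 16 MiB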