mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 06:35:00 +08:00
[Misc] correct static type check for GroupCoordinator (#21946)
Signed-off-by: Andy Xie <andy.xning@gmail.com>
This commit is contained in:
parent
83156c7b89
commit
74333ae2f6
@ -70,6 +70,7 @@ class RayPPCommunicator(Communicator):
|
||||
assert ray.get_gpu_ids(), "RayPPCommunicator has no GPUs assigned"
|
||||
|
||||
self._comm = get_pp_group().device_communicator
|
||||
assert self._comm is not None
|
||||
|
||||
# Since we wrap around the vLLM _PP communicator, we use
|
||||
# the rank from the vLLM communicator, and ignore the rank
|
||||
|
||||
@ -251,6 +251,7 @@ class EplbState:
|
||||
|
||||
if global_expert_load is not None:
|
||||
ep_group = get_ep_group().device_group
|
||||
assert ep_group is not None
|
||||
assert global_expert_load.shape == (model.num_moe_layers,
|
||||
model.num_logical_experts)
|
||||
assert global_expert_load.dtype == torch.int64
|
||||
@ -357,6 +358,7 @@ class EplbState:
|
||||
|
||||
# Collect load metrics from all ranks
|
||||
ep_group = get_ep_group().device_group
|
||||
assert ep_group is not None
|
||||
num_tokens_list = [
|
||||
torch.empty_like(num_tokens) for _ in range(ep_group.size())
|
||||
]
|
||||
@ -412,6 +414,7 @@ class EplbState:
|
||||
"""
|
||||
|
||||
ep_group = get_ep_group().device_group
|
||||
assert ep_group is not None
|
||||
ep_rank = ep_group.rank()
|
||||
|
||||
time_start = None
|
||||
|
||||
@ -196,10 +196,11 @@ class GroupCoordinator:
|
||||
# 3 | 1 | 3 | 1 | 3
|
||||
local_rank: int # local rank used to assign devices
|
||||
rank_in_group: int # rank inside the group
|
||||
cpu_group: ProcessGroup # group for CPU communication
|
||||
device_group: ProcessGroup # group for device communication
|
||||
cpu_group: Optional[ProcessGroup] # group for CPU communication
|
||||
device_group: Optional[ProcessGroup] # group for device communication
|
||||
use_device_communicator: bool # whether to use device communicator
|
||||
device_communicator: DeviceCommunicatorBase # device communicator
|
||||
device_communicator: Optional[
|
||||
DeviceCommunicatorBase] # device communicator
|
||||
mq_broadcaster: Optional[Any] # shared memory broadcaster
|
||||
|
||||
def __init__(
|
||||
@ -250,7 +251,7 @@ class GroupCoordinator:
|
||||
|
||||
self.use_device_communicator = use_device_communicator
|
||||
|
||||
self.device_communicator: DeviceCommunicatorBase = None # type: ignore
|
||||
self.device_communicator = None
|
||||
if use_device_communicator and self.world_size > 1:
|
||||
device_comm_cls = resolve_obj_by_qualname(
|
||||
current_platform.get_device_communicator_cls())
|
||||
@ -364,6 +365,8 @@ class GroupCoordinator:
|
||||
return self._all_reduce_out_place(input_)
|
||||
|
||||
def _all_reduce_out_place(self, input_: torch.Tensor) -> torch.Tensor:
|
||||
if self.device_communicator is None:
|
||||
raise ValueError("No device communicator found")
|
||||
return self.device_communicator.all_reduce(input_)
|
||||
|
||||
def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
|
||||
@ -384,12 +387,16 @@ class GroupCoordinator:
|
||||
|
||||
def _all_gather_out_place(self, input_: torch.Tensor,
|
||||
dim: int) -> torch.Tensor:
|
||||
if self.device_communicator is None:
|
||||
raise ValueError("No device communicator found")
|
||||
return self.device_communicator.all_gather(input_, dim)
|
||||
|
||||
def all_gatherv(self,
|
||||
input_: Union[torch.Tensor, list[torch.Tensor]],
|
||||
dim: int = 0,
|
||||
sizes: Optional[list[int]] = None):
|
||||
if self.device_communicator is None:
|
||||
raise ValueError("No device communicator found")
|
||||
return self.device_communicator.all_gatherv(input_, dim, sizes)
|
||||
|
||||
def reduce_scatter(self,
|
||||
@ -414,10 +421,14 @@ class GroupCoordinator:
|
||||
input_: torch.Tensor,
|
||||
dim: int = -1,
|
||||
sizes: Optional[list[int]] = None) -> torch.Tensor:
|
||||
if self.device_communicator is None:
|
||||
raise ValueError("No device communicator found")
|
||||
return self.device_communicator.reduce_scatterv(input_, dim, sizes)
|
||||
|
||||
def _reduce_scatter_out_place(self, input_: torch.Tensor,
|
||||
dim: int) -> torch.Tensor:
|
||||
if self.device_communicator is None:
|
||||
raise ValueError("No device communicator found")
|
||||
return self.device_communicator.reduce_scatter(input_, dim)
|
||||
|
||||
def gather(self,
|
||||
@ -433,6 +444,8 @@ class GroupCoordinator:
|
||||
# Bypass the function if we are using only 1 GPU.
|
||||
if world_size == 1:
|
||||
return input_
|
||||
if self.device_communicator is None:
|
||||
raise ValueError("No device communicator found")
|
||||
return self.device_communicator.gather(input_, dst, dim)
|
||||
|
||||
def broadcast(self, input_: torch.Tensor, src: int = 0):
|
||||
@ -667,6 +680,8 @@ class GroupCoordinator:
|
||||
assert dst < self.world_size, f"Invalid dst rank ({dst})"
|
||||
|
||||
if self.use_cpu_custom_send_recv:
|
||||
if self.device_communicator is None:
|
||||
raise ValueError("No device communicator found")
|
||||
self.device_communicator.send_tensor_dict( # type: ignore
|
||||
tensor_dict, dst)
|
||||
return None
|
||||
@ -727,6 +742,8 @@ class GroupCoordinator:
|
||||
assert src < self.world_size, f"Invalid src rank ({src})"
|
||||
|
||||
if self.use_cpu_custom_send_recv:
|
||||
if self.device_communicator is None:
|
||||
raise ValueError("No device communicator found")
|
||||
return self.device_communicator.recv_tensor_dict( # type: ignore
|
||||
src)
|
||||
|
||||
@ -784,6 +801,8 @@ class GroupCoordinator:
|
||||
def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None:
|
||||
"""Sends a tensor to the destination rank in a blocking way"""
|
||||
"""NOTE: `dst` is the local rank of the destination rank."""
|
||||
if self.device_communicator is None:
|
||||
raise ValueError("No device communicator found")
|
||||
self.device_communicator.send(tensor, dst)
|
||||
|
||||
def recv(self,
|
||||
@ -792,6 +811,8 @@ class GroupCoordinator:
|
||||
src: Optional[int] = None) -> torch.Tensor:
|
||||
"""Receives a tensor from the source rank."""
|
||||
"""NOTE: `src` is the local rank of the source rank."""
|
||||
if self.device_communicator is None:
|
||||
raise ValueError("No device communicator found")
|
||||
return self.device_communicator.recv(size, dtype, src)
|
||||
|
||||
def destroy(self):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user