[Misc] correct static type check for GroupCoordinator (#21946)

Signed-off-by: Andy Xie <andy.xning@gmail.com>
Ning Xie 2025-08-05 18:17:46 +08:00 committed by GitHub
parent 83156c7b89
commit 74333ae2f6
3 changed files with 29 additions and 4 deletions

View File

@@ -70,6 +70,7 @@ class RayPPCommunicator(Communicator):
         assert ray.get_gpu_ids(), "RayPPCommunicator has no GPUs assigned"
         self._comm = get_pp_group().device_communicator
+        assert self._comm is not None
         # Since we wrap around the vLLM _PP communicator, we use
         # the rank from the vLLM communicator, and ignore the rank
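
With this change, `device_communicator` on the group coordinator is typed as Optional, so `get_pp_group().device_communicator` may be None as far as the checker is concerned; the added assert documents that it must exist here and narrows the type. A minimal sketch of the pattern, using stand-in names rather than the vLLM classes:

from typing import Optional


class Comm:  # stand-in for DeviceCommunicatorBase
    def all_reduce(self, x: int) -> int:
        return x


class Group:  # stand-in for the group coordinator
    device_communicator: Optional[Comm] = None


def use(group: Group) -> int:
    comm = group.device_communicator
    # Without this assert, a strict checker rejects `comm.all_reduce(...)`
    # because `comm` may be None; with it, the type is narrowed to Comm.
    assert comm is not None
    return comm.all_reduce(1)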

View File

@@ -251,6 +251,7 @@ class EplbState:
         if global_expert_load is not None:
             ep_group = get_ep_group().device_group
+            assert ep_group is not None
             assert global_expert_load.shape == (model.num_moe_layers,
                                                 model.num_logical_experts)
             assert global_expert_load.dtype == torch.int64
@@ -357,6 +358,7 @@ class EplbState:
         # Collect load metrics from all ranks
         ep_group = get_ep_group().device_group
+        assert ep_group is not None
         num_tokens_list = [
             torch.empty_like(num_tokens) for _ in range(ep_group.size())
         ]
@@ -412,6 +414,7 @@ class EplbState:
         """
         ep_group = get_ep_group().device_group
+        assert ep_group is not None
         ep_rank = ep_group.rank()
         time_start = None
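
The same guard is applied to the expert-parallel group: `get_ep_group().device_group` is now Optional, and the asserts both satisfy the checker and fail fast with an AssertionError if the group was never initialized. A compact sketch with stand-in types (not torch.distributed):

from typing import Optional


class FakeProcessGroup:  # stand-in for torch.distributed.ProcessGroup
    def size(self) -> int:
        return 1


def current_device_group() -> Optional[FakeProcessGroup]:
    # May legitimately be None before distributed init.
    return FakeProcessGroup()


def collect_metrics() -> int:
    ep_group = current_device_group()
    assert ep_group is not None  # narrows Optional[...] for the checker
    return ep_group.size()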

View File

@@ -196,10 +196,11 @@ class GroupCoordinator:
     # 3 | 1 | 3 | 1 | 3
     local_rank: int  # local rank used to assign devices
     rank_in_group: int  # rank inside the group
-    cpu_group: ProcessGroup  # group for CPU communication
-    device_group: ProcessGroup  # group for device communication
+    cpu_group: Optional[ProcessGroup]  # group for CPU communication
+    device_group: Optional[ProcessGroup]  # group for device communication
     use_device_communicator: bool  # whether to use device communicator
-    device_communicator: DeviceCommunicatorBase  # device communicator
+    device_communicator: Optional[
+        DeviceCommunicatorBase]  # device communicator
     mq_broadcaster: Optional[Any]  # shared memory broadcaster

     def __init__(
@@ -250,7 +251,7 @@
         self.use_device_communicator = use_device_communicator
-        self.device_communicator: DeviceCommunicatorBase = None  # type: ignore
+        self.device_communicator = None
         if use_device_communicator and self.world_size > 1:
             device_comm_cls = resolve_obj_by_qualname(
                 current_platform.get_device_communicator_cls())
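
Declaring the attributes as Optional at class level is what lets the constructor drop both the inline annotation and the `# type: ignore`: assigning None is now consistent with the declared type, and every later use site has to handle the None case. A minimal sketch with stand-in names:

from typing import Optional


class Comm:  # stand-in for a device communicator
    pass


class Coordinator:
    device_communicator: Optional[Comm]  # declared Optional once, up front

    def __init__(self, use_device_communicator: bool) -> None:
        self.device_communicator = None  # ok: the declared type allows None
        if use_device_communicator:
            self.device_communicator = Comm()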
@@ -364,6 +365,8 @@ class GroupCoordinator:
         return self._all_reduce_out_place(input_)

     def _all_reduce_out_place(self, input_: torch.Tensor) -> torch.Tensor:
+        if self.device_communicator is None:
+            raise ValueError("No device communicator found")
         return self.device_communicator.all_reduce(input_)

     def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
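
In the collective helpers, the None case is handled with an explicit `raise ValueError` rather than an assert. Either form narrows the type for the checker, but the raise also survives `python -O` (which strips asserts) and turns a confusing AttributeError on None into a clear error. A sketch of the pattern with stand-in names:

from typing import Optional


class Comm:  # stand-in for a device communicator
    def all_reduce(self, x: float) -> float:
        return x


class Coordinator:
    device_communicator: Optional[Comm] = None

    def all_reduce(self, x: float) -> float:
        if self.device_communicator is None:
            raise ValueError("No device communicator found")
        # After the guard, the checker treats device_communicator as Comm.
        return self.device_communicator.all_reduce(x)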
@@ -384,12 +387,16 @@ class GroupCoordinator:
     def _all_gather_out_place(self, input_: torch.Tensor,
                               dim: int) -> torch.Tensor:
+        if self.device_communicator is None:
+            raise ValueError("No device communicator found")
         return self.device_communicator.all_gather(input_, dim)

     def all_gatherv(self,
                     input_: Union[torch.Tensor, list[torch.Tensor]],
                     dim: int = 0,
                     sizes: Optional[list[int]] = None):
+        if self.device_communicator is None:
+            raise ValueError("No device communicator found")
         return self.device_communicator.all_gatherv(input_, dim, sizes)

     def reduce_scatter(self,
@@ -414,10 +421,14 @@ class GroupCoordinator:
                         input_: torch.Tensor,
                         dim: int = -1,
                         sizes: Optional[list[int]] = None) -> torch.Tensor:
+        if self.device_communicator is None:
+            raise ValueError("No device communicator found")
         return self.device_communicator.reduce_scatterv(input_, dim, sizes)

     def _reduce_scatter_out_place(self, input_: torch.Tensor,
                                   dim: int) -> torch.Tensor:
+        if self.device_communicator is None:
+            raise ValueError("No device communicator found")
         return self.device_communicator.reduce_scatter(input_, dim)

     def gather(self,
@@ -433,6 +444,8 @@ class GroupCoordinator:
         # Bypass the function if we are using only 1 GPU.
         if world_size == 1:
             return input_
+        if self.device_communicator is None:
+            raise ValueError("No device communicator found")
         return self.device_communicator.gather(input_, dst, dim)

     def broadcast(self, input_: torch.Tensor, src: int = 0):
@@ -667,6 +680,8 @@ class GroupCoordinator:
         assert dst < self.world_size, f"Invalid dst rank ({dst})"
         if self.use_cpu_custom_send_recv:
+            if self.device_communicator is None:
+                raise ValueError("No device communicator found")
             self.device_communicator.send_tensor_dict(  # type: ignore
                 tensor_dict, dst)
             return None
@@ -727,6 +742,8 @@ class GroupCoordinator:
         assert src < self.world_size, f"Invalid src rank ({src})"
         if self.use_cpu_custom_send_recv:
+            if self.device_communicator is None:
+                raise ValueError("No device communicator found")
             return self.device_communicator.recv_tensor_dict(  # type: ignore
                 src)
@@ -784,6 +801,8 @@ class GroupCoordinator:
     def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None:
         """Sends a tensor to the destination rank in a blocking way"""
         """NOTE: `dst` is the local rank of the destination rank."""
+        if self.device_communicator is None:
+            raise ValueError("No device communicator found")
         self.device_communicator.send(tensor, dst)

     def recv(self,
@@ -792,6 +811,8 @@ class GroupCoordinator:
              src: Optional[int] = None) -> torch.Tensor:
         """Receives a tensor from the source rank."""
         """NOTE: `src` is the local rank of the source rank."""
+        if self.device_communicator is None:
+            raise ValueError("No device communicator found")
         return self.device_communicator.recv(size, dtype, src)

     def destroy(self):
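
One way to confirm that a change like this keeps the static checker clean is to run mypy over the touched modules. A minimal sketch using mypy's programmatic API; the file path is a placeholder, not taken from this commit:

from mypy import api

# Equivalent to running `mypy path/to/changed_file.py` on the command line.
stdout, stderr, exit_status = api.run(["path/to/changed_file.py"])
print(stdout or stderr)
print("type check passed" if exit_status == 0 else "type errors found")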