[Misc] Further refine type annotations in parallel state (#22499)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Authored by Cyrus Leung on 2025-08-10 20:49:48 +08:00; committed by GitHub
parent 010e0e39ea
commit d411df0296
2 changed files with 19 additions and 20 deletions


@@ -259,7 +259,6 @@ class EplbState:
         if global_expert_load is not None:
             ep_group = get_ep_group().device_group
-            assert ep_group is not None
             assert global_expert_load.shape == (model.num_moe_layers,
                                                 model.num_logical_experts)
             assert global_expert_load.dtype == torch.int64
@@ -366,7 +365,6 @@ class EplbState:
         # Collect load metrics from all ranks
         ep_group = get_ep_group().device_group
-        assert ep_group is not None
         all_reduce(total_expert_load_pass, group=ep_group)
         # num_tokens_per_rank: (num_moe_layers, num_ranks)
@@ -422,7 +420,6 @@ class EplbState:
         """
         ep_group = get_ep_group().device_group
-        assert ep_group is not None
         ep_rank = ep_group.rank()
         time_start = None

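The three removed assertions all guard the same value, `get_ep_group().device_group`. Once `GroupCoordinator.device_group` is annotated as a plain `ProcessGroup` instead of `Optional[ProcessGroup]` (second file below), the `is not None` checks add nothing for the type checker or the reader. A minimal sketch of that narrowing, using hypothetical `Before`/`After`/`rank_of` names rather than the real vLLM classes:

from typing import Optional

from torch.distributed import ProcessGroup


class Before:
    # Optional annotation: every caller must narrow the type first.
    device_group: Optional[ProcessGroup]


class After:
    # Non-Optional annotation: the attribute is guaranteed to be set.
    device_group: ProcessGroup


def rank_of(coord: After) -> int:
    # No `assert coord.device_group is not None` needed here: the annotation
    # already rules out None, so the call type-checks cleanly.
    return coord.device_group.rank()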

@@ -197,11 +197,10 @@ class GroupCoordinator:
     # 3 | 1 | 3 | 1 | 3
     local_rank: int  # local rank used to assign devices
     rank_in_group: int  # rank inside the group
-    cpu_group: Optional[ProcessGroup]  # group for CPU communication
-    device_group: Optional[ProcessGroup]  # group for device communication
-    use_device_communicator: bool  # whether to use device communicator
-    device_communicator: Optional[
-        DeviceCommunicatorBase]  # device communicator
+    cpu_group: ProcessGroup  # group for CPU communication
+    device_group: ProcessGroup  # group for device communication
+    # device communicator (if use_device_communicator=True)
+    device_communicator: Optional[DeviceCommunicatorBase]
     mq_broadcaster: Optional[Any]  # shared memory broadcaster

     def __init__(
@@ -209,7 +208,7 @@ class GroupCoordinator:
         group_ranks: list[list[int]],
         local_rank: int,
         torch_distributed_backend: Union[str, Backend],
-        use_device_communicator: bool,
+        use_device_communicator: bool,  # whether to use device communicator
         use_message_queue_broadcaster: bool = False,
         group_name: Optional[str] = None,
     ):
@@ -219,8 +218,9 @@ class GroupCoordinator:
         self.rank = torch.distributed.get_rank()
         self.local_rank = local_rank
-        self.device_group = None
-        self.cpu_group = None
+        self_device_group = None
+        self_cpu_group = None

         for ranks in group_ranks:
             device_group = torch.distributed.new_group(
@@ -232,11 +232,14 @@ class GroupCoordinator:
                 self.ranks = ranks
                 self.world_size = len(ranks)
                 self.rank_in_group = ranks.index(self.rank)
-                self.device_group = device_group
-                self.cpu_group = cpu_group
+                self_device_group = device_group
+                self_cpu_group = cpu_group

-        assert self.cpu_group is not None
-        assert self.device_group is not None
+        assert self_cpu_group is not None
+        assert self_device_group is not None
+        self.cpu_group = self_cpu_group
+        self.device_group = self_device_group

         from vllm.platforms import current_platform
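Inside `__init__`, the groups are now built in the local variables `self_device_group` and `self_cpu_group`, and the attributes are assigned only after the asserts, so `cpu_group` and `device_group` are never observed holding `None` and can keep their non-Optional annotations. A small sketch of that pattern in isolation, with hypothetical names (`Coordinator`, `group_name`, `found`) and none of the real torch.distributed setup:

from typing import Optional


class Coordinator:
    # Non-Optional attribute: by the time __init__ returns it is always set.
    group_name: str

    def __init__(self, group_ranks: list[list[int]], rank: int) -> None:
        # The local temporary may be None while we search for our group.
        found: Optional[str] = None
        for ranks in group_ranks:
            if rank in ranks:
                found = f"group-{min(ranks)}"
        # Narrow the local once, then assign the attribute exactly once.
        assert found is not None, "rank must belong to one of the groups"
        self.group_name = found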
@@ -251,7 +254,6 @@ class GroupCoordinator:
             self.device = torch.device("cpu")

         self.use_device_communicator = use_device_communicator
         self.device_communicator = None
         if use_device_communicator and self.world_size > 1:
             device_comm_cls = resolve_obj_by_qualname(
@@ -817,12 +819,12 @@ class GroupCoordinator:
         return self.device_communicator.recv(size, dtype, src)

     def destroy(self):
-        if self.device_group is not None:
+        if hasattr(self, "device_group"):
             torch.distributed.destroy_process_group(self.device_group)
-            self.device_group = None
-        if self.cpu_group is not None:
+            del self.device_group
+        if hasattr(self, "cpu_group"):
             torch.distributed.destroy_process_group(self.cpu_group)
-            self.cpu_group = None
+            del self.cpu_group
         if self.device_communicator is not None:
             self.device_communicator.destroy()
         if self.mq_broadcaster is not None:
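Because the annotations no longer admit `None`, `destroy()` can no longer mark a group as released by writing `None` back into the attribute; it deletes the attribute and tests for it with `hasattr()` instead. A compact sketch of that idiom, using a hypothetical `Holder` class with an ordinary dict standing in for the process group:

class Holder:
    # Non-Optional attribute; once released it simply no longer exists.
    resource: dict

    def __init__(self) -> None:
        self.resource = {"handle": 1}

    def destroy(self) -> None:
        # hasattr() replaces the old `is not None` check, and `del` replaces
        # assigning None, so the annotation never has to admit None.
        if hasattr(self, "resource"):
            self.resource.clear()
            del self.resource


h = Holder()
h.destroy()
h.destroy()  # safe to call twice: the attribute is gone after the first call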