diff --git a/vllm/distributed/eplb/eplb_state.py b/vllm/distributed/eplb/eplb_state.py
index c415d409f7fe..979f2a06cec9 100644
--- a/vllm/distributed/eplb/eplb_state.py
+++ b/vllm/distributed/eplb/eplb_state.py
@@ -259,7 +259,6 @@ class EplbState:
 
         if global_expert_load is not None:
             ep_group = get_ep_group().device_group
-            assert ep_group is not None
             assert global_expert_load.shape == (model.num_moe_layers,
                                                 model.num_logical_experts)
             assert global_expert_load.dtype == torch.int64
@@ -366,7 +365,6 @@ class EplbState:
 
         # Collect load metrics from all ranks
         ep_group = get_ep_group().device_group
-        assert ep_group is not None
         all_reduce(total_expert_load_pass, group=ep_group)
 
         # num_tokens_per_rank: (num_moe_layers, num_ranks)
@@ -422,7 +420,6 @@ class EplbState:
         """
 
         ep_group = get_ep_group().device_group
-        assert ep_group is not None
         ep_rank = ep_group.rank()
 
         time_start = None
diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py
index 0b3993ca0275..b89aee99c8d4 100644
--- a/vllm/distributed/parallel_state.py
+++ b/vllm/distributed/parallel_state.py
@@ -197,11 +197,10 @@ class GroupCoordinator:
     # 3 | 1 | 3 | 1 | 3
     local_rank: int  # local rank used to assign devices
     rank_in_group: int  # rank inside the group
-    cpu_group: Optional[ProcessGroup]  # group for CPU communication
-    device_group: Optional[ProcessGroup]  # group for device communication
-    use_device_communicator: bool  # whether to use device communicator
-    device_communicator: Optional[
-        DeviceCommunicatorBase]  # device communicator
+    cpu_group: ProcessGroup  # group for CPU communication
+    device_group: ProcessGroup  # group for device communication
+    # device communicator (if use_device_communicator=True)
+    device_communicator: Optional[DeviceCommunicatorBase]
     mq_broadcaster: Optional[Any]  # shared memory broadcaster
 
     def __init__(
@@ -209,7 +208,7 @@ class GroupCoordinator:
         group_ranks: list[list[int]],
         local_rank: int,
         torch_distributed_backend: Union[str, Backend],
-        use_device_communicator: bool,
+        use_device_communicator: bool,  # whether to use device communicator
         use_message_queue_broadcaster: bool = False,
         group_name: Optional[str] = None,
     ):
@@ -219,8 +218,9 @@ class GroupCoordinator:
 
         self.rank = torch.distributed.get_rank()
         self.local_rank = local_rank
-        self.device_group = None
-        self.cpu_group = None
+
+        self_device_group = None
+        self_cpu_group = None
 
         for ranks in group_ranks:
             device_group = torch.distributed.new_group(
@@ -232,11 +232,14 @@ class GroupCoordinator:
                 self.ranks = ranks
                 self.world_size = len(ranks)
                 self.rank_in_group = ranks.index(self.rank)
-                self.device_group = device_group
-                self.cpu_group = cpu_group
+                self_device_group = device_group
+                self_cpu_group = cpu_group
 
-        assert self.cpu_group is not None
-        assert self.device_group is not None
+        assert self_cpu_group is not None
+        assert self_device_group is not None
+
+        self.cpu_group = self_cpu_group
+        self.device_group = self_device_group
 
         from vllm.platforms import current_platform
 
@@ -251,7 +254,6 @@ class GroupCoordinator:
             self.device = torch.device("cpu")
 
         self.use_device_communicator = use_device_communicator
 
-        self.device_communicator = None
         if use_device_communicator and self.world_size > 1:
             device_comm_cls = resolve_obj_by_qualname(
@@ -817,12 +819,12 @@ class GroupCoordinator:
         return self.device_communicator.recv(size, dtype, src)
 
     def destroy(self):
-        if self.device_group is not None:
+        if hasattr(self, "device_group"):
             torch.distributed.destroy_process_group(self.device_group)
-            self.device_group = None
-        if self.cpu_group is not None:
+            del self.device_group
+        if hasattr(self, "cpu_group"):
             torch.distributed.destroy_process_group(self.cpu_group)
-            self.cpu_group = None
+            del self.cpu_group
         if self.device_communicator is not None:
             self.device_communicator.destroy()
         if self.mq_broadcaster is not None: