Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-09 22:55:51 +08:00)
[Misc] Further refine type annotations in parallel state (#22499)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 010e0e39ea
commit d411df0296
@@ -259,7 +259,6 @@ class EplbState:
 
         if global_expert_load is not None:
             ep_group = get_ep_group().device_group
-            assert ep_group is not None
             assert global_expert_load.shape == (model.num_moe_layers,
                                                 model.num_logical_experts)
             assert global_expert_load.dtype == torch.int64
@@ -366,7 +365,6 @@ class EplbState:
 
         # Collect load metrics from all ranks
         ep_group = get_ep_group().device_group
-        assert ep_group is not None
         all_reduce(total_expert_load_pass, group=ep_group)
 
         # num_tokens_per_rank: (num_moe_layers, num_ranks)
@@ -422,7 +420,6 @@ class EplbState:
         """
 
         ep_group = get_ep_group().device_group
-        assert ep_group is not None
         ep_rank = ep_group.rank()
 
         time_start = None
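Note on the three EplbState hunks above: each drops the same `assert ep_group is not None`, which was only there to narrow the old `Optional[ProcessGroup]` annotation on `GroupCoordinator.device_group` (tightened in the hunks below). A minimal sketch of the typing effect, using hypothetical class and function names rather than vLLM's actual API:

```python
from typing import Optional

from torch.distributed import ProcessGroup


class CoordinatorBefore:
    # Optional annotation: type checkers see `ProcessGroup | None`,
    # so every caller has to narrow it (e.g. with an assert) before use.
    device_group: Optional[ProcessGroup]


class CoordinatorAfter:
    # Non-Optional annotation: callers can use the group directly.
    device_group: ProcessGroup


def ep_rank_before(coord: CoordinatorBefore) -> int:
    ep_group = coord.device_group
    assert ep_group is not None  # needed only to satisfy the type checker
    return ep_group.rank()


def ep_rank_after(coord: CoordinatorAfter) -> int:
    # No narrowing needed; the annotation already guarantees a ProcessGroup.
    return coord.device_group.rank()
```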
@@ -197,11 +197,10 @@ class GroupCoordinator:
     # 3 | 1 | 3 | 1 | 3
     local_rank: int  # local rank used to assign devices
     rank_in_group: int  # rank inside the group
-    cpu_group: Optional[ProcessGroup]  # group for CPU communication
-    device_group: Optional[ProcessGroup]  # group for device communication
-    use_device_communicator: bool  # whether to use device communicator
-    device_communicator: Optional[
-        DeviceCommunicatorBase]  # device communicator
+    cpu_group: ProcessGroup  # group for CPU communication
+    device_group: ProcessGroup  # group for device communication
+    # device communicator (if use_device_communicator=True)
+    device_communicator: Optional[DeviceCommunicatorBase]
     mq_broadcaster: Optional[Any]  # shared memory broadcaster
 
     def __init__(
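For context on the attribute block above (standard Python behavior, illustrated with made-up names): bare annotations in a class body only declare instance attributes for type checkers and are recorded in `__annotations__`; they do not create class attributes or assign any value at runtime.

```python
from typing import Optional


class Example:
    # Declarations only: visible to type checkers and in __annotations__,
    # but no attribute exists on the class itself.
    group: int
    broadcaster: Optional[str]


print(Example.__annotations__)
# {'group': <class 'int'>, 'broadcaster': typing.Optional[str]}
print(hasattr(Example, "group"))  # False
```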
@@ -209,7 +208,7 @@ class GroupCoordinator:
         group_ranks: list[list[int]],
         local_rank: int,
         torch_distributed_backend: Union[str, Backend],
-        use_device_communicator: bool,
+        use_device_communicator: bool,  # whether to use device communicator
         use_message_queue_broadcaster: bool = False,
         group_name: Optional[str] = None,
     ):
@@ -219,8 +218,9 @@ class GroupCoordinator:
 
         self.rank = torch.distributed.get_rank()
         self.local_rank = local_rank
-        self.device_group = None
-        self.cpu_group = None
+
+        self_device_group = None
+        self_cpu_group = None
 
         for ranks in group_ranks:
             device_group = torch.distributed.new_group(
@@ -232,11 +232,14 @@ class GroupCoordinator:
                 self.ranks = ranks
                 self.world_size = len(ranks)
                 self.rank_in_group = ranks.index(self.rank)
-                self.device_group = device_group
-                self.cpu_group = cpu_group
+                self_device_group = device_group
+                self_cpu_group = cpu_group
 
-        assert self.cpu_group is not None
-        assert self.device_group is not None
+        assert self_cpu_group is not None
+        assert self_device_group is not None
+
+        self.cpu_group = self_cpu_group
+        self.device_group = self_device_group
 
         from vllm.platforms import current_platform
 
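The two `__init__` hunks above switch to accumulating the groups in local variables, asserting on the locals once the loop has run, and only then publishing them to `self`. That way the attributes are assigned exactly once and never hold `None`, so the non-Optional class annotations stay truthful. A minimal sketch of the pattern under hypothetical names (not the real `GroupCoordinator` constructor):

```python
from typing import Optional

from torch.distributed import ProcessGroup


class Coordinator:
    cpu_group: ProcessGroup  # non-Optional: always set after __init__

    def __init__(self, candidate_groups: list[Optional[ProcessGroup]]) -> None:
        # The local is Optional while we search for the matching group.
        self_cpu_group: Optional[ProcessGroup] = None
        for group in candidate_groups:
            if group is not None:
                self_cpu_group = group

        # Narrow the local once, then assign to the attribute, so the
        # attribute type is plain ProcessGroup from the checker's view.
        assert self_cpu_group is not None
        self.cpu_group = self_cpu_group
```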
@@ -251,7 +254,6 @@ class GroupCoordinator:
             self.device = torch.device("cpu")
 
-        self.use_device_communicator = use_device_communicator
 
         self.device_communicator = None
         if use_device_communicator and self.world_size > 1:
             device_comm_cls = resolve_obj_by_qualname(
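In the hunk above, `self.device_communicator` keeps its `Optional` type because it is only constructed when the flag is set and the group spans more than one rank; `resolve_obj_by_qualname` is the vLLM utility that turns a dotted path into the communicator class to instantiate. A rough sketch of what such a qualname resolver typically does (an assumption for illustration, not vLLM's actual implementation):

```python
import importlib
from typing import Any


def resolve_obj_by_qualname_sketch(qualname: str) -> Any:
    # "package.module.ClassName" -> import the module, fetch the attribute.
    module_name, _, obj_name = qualname.rpartition(".")
    module = importlib.import_module(module_name)
    return getattr(module, obj_name)


# Example: resolve a standard-library class by its qualified name.
ordered_dict_cls = resolve_obj_by_qualname_sketch("collections.OrderedDict")
assert ordered_dict_cls().__class__.__name__ == "OrderedDict"
```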
@@ -817,12 +819,12 @@ class GroupCoordinator:
         return self.device_communicator.recv(size, dtype, src)
 
     def destroy(self):
-        if self.device_group is not None:
+        if hasattr(self, "device_group"):
             torch.distributed.destroy_process_group(self.device_group)
-            self.device_group = None
-        if self.cpu_group is not None:
+            del self.device_group
+        if hasattr(self, "cpu_group"):
             torch.distributed.destroy_process_group(self.cpu_group)
-            self.cpu_group = None
+            del self.cpu_group
         if self.device_communicator is not None:
             self.device_communicator.destroy()
         if self.mq_broadcaster is not None:
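The `destroy()` hunk stops resetting the groups to `None` (which would contradict the new non-Optional annotations) and instead deletes the attributes, guarding with `hasattr` so teardown stays safe if it runs twice or before the attributes were set. A small sketch of the same idea with a stand-in resource and hypothetical names:

```python
class Holder:
    resource: str  # non-Optional: either present or deleted, never None

    def __init__(self) -> None:
        self.resource = "handle"

    def destroy(self) -> None:
        # hasattr/del keeps the annotation honest: after teardown the
        # attribute is simply absent instead of holding an out-of-type None.
        if hasattr(self, "resource"):
            del self.resource


h = Holder()
h.destroy()
h.destroy()  # second call is a safe no-op thanks to the hasattr guard
```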