[Misc] correct static type check for GroupCoordinator (#21946)
Signed-off-by: Andy Xie <andy.xning@gmail.com>
parent 83156c7b89
commit 74333ae2f6
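For context (not part of the commit itself): the change annotates GroupCoordinator's cpu_group, device_group, and device_communicator fields as Optional and guards every dereference of the device communicator so the static type checker can narrow the Optional away. A minimal sketch of the pattern, assuming a simplified stand-in class rather than the real vLLM GroupCoordinator:

from typing import Optional

import torch


class DeviceCommunicatorBase:
    """Stand-in for vLLM's device communicator base class."""

    def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
        return input_


class GroupCoordinator:
    # Declared Optional at class level, mirroring the diff below.
    device_communicator: Optional[DeviceCommunicatorBase]

    def __init__(self, use_device_communicator: bool, world_size: int) -> None:
        self.device_communicator = None
        if use_device_communicator and world_size > 1:
            self.device_communicator = DeviceCommunicatorBase()

    def _all_reduce_out_place(self, input_: torch.Tensor) -> torch.Tensor:
        # The explicit None check both narrows the Optional for mypy and
        # gives callers a clear error instead of an AttributeError on None.
        if self.device_communicator is None:
            raise ValueError("No device communicator found")
        return self.device_communicator.all_reduce(input_)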
@@ -70,6 +70,7 @@ class RayPPCommunicator(Communicator):
         assert ray.get_gpu_ids(), "RayPPCommunicator has no GPUs assigned"
 
         self._comm = get_pp_group().device_communicator
+        assert self._comm is not None
 
         # Since we wrap around the vLLM _PP communicator, we use
         # the rank from the vLLM communicator, and ignore the rank
@@ -251,6 +251,7 @@ class EplbState:
 
         if global_expert_load is not None:
             ep_group = get_ep_group().device_group
+            assert ep_group is not None
             assert global_expert_load.shape == (model.num_moe_layers,
                                                 model.num_logical_experts)
             assert global_expert_load.dtype == torch.int64
@@ -357,6 +358,7 @@ class EplbState:
 
         # Collect load metrics from all ranks
         ep_group = get_ep_group().device_group
+        assert ep_group is not None
         num_tokens_list = [
            torch.empty_like(num_tokens) for _ in range(ep_group.size())
        ]
@@ -412,6 +414,7 @@ class EplbState:
         """
 
         ep_group = get_ep_group().device_group
+        assert ep_group is not None
         ep_rank = ep_group.rank()
 
         time_start = None
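The RayPPCommunicator and EplbState hunks above rely on assert-based narrowing rather than raising: asserting that the value is not None is enough for the checker to treat it as non-Optional afterwards. A small self-contained sketch, with a hypothetical helper standing in for get_ep_group().device_group:

from typing import Optional

from torch.distributed import ProcessGroup


def get_device_group() -> Optional[ProcessGroup]:
    # Hypothetical stand-in; the real accessor may return None before init.
    return None


def collect_load() -> int:
    ep_group = get_device_group()
    # Without this assert, mypy reports that "None" has no attribute "size";
    # with it, ep_group narrows to ProcessGroup for the rest of the function.
    assert ep_group is not None
    return ep_group.size()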
@@ -196,10 +196,11 @@ class GroupCoordinator:
     # 3 | 1 | 3 | 1 | 3
     local_rank: int  # local rank used to assign devices
     rank_in_group: int  # rank inside the group
-    cpu_group: ProcessGroup  # group for CPU communication
-    device_group: ProcessGroup  # group for device communication
+    cpu_group: Optional[ProcessGroup]  # group for CPU communication
+    device_group: Optional[ProcessGroup]  # group for device communication
     use_device_communicator: bool  # whether to use device communicator
-    device_communicator: DeviceCommunicatorBase  # device communicator
+    device_communicator: Optional[
+        DeviceCommunicatorBase]  # device communicator
     mq_broadcaster: Optional[Any]  # shared memory broadcaster
 
     def __init__(
@@ -250,7 +251,7 @@ class GroupCoordinator:
 
         self.use_device_communicator = use_device_communicator
 
-        self.device_communicator: DeviceCommunicatorBase = None  # type: ignore
+        self.device_communicator = None
         if use_device_communicator and self.world_size > 1:
             device_comm_cls = resolve_obj_by_qualname(
                 current_platform.get_device_communicator_cls())
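A before/after sketch of why the old assignment needed the `# type: ignore` while the new one does not, assuming mypy-style checking and illustrative class names:

from typing import Optional


class DeviceCommunicatorBase:
    pass


class Before:
    def __init__(self) -> None:
        # mypy: Incompatible types in assignment (expression has type "None",
        # variable has type "DeviceCommunicatorBase"), hence the suppression.
        self.device_communicator: DeviceCommunicatorBase = None  # type: ignore


class After:
    # The class-level annotation is Optional, so assigning None is well typed.
    device_communicator: Optional[DeviceCommunicatorBase]

    def __init__(self) -> None:
        self.device_communicator = None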
@@ -364,6 +365,8 @@ class GroupCoordinator:
         return self._all_reduce_out_place(input_)
 
     def _all_reduce_out_place(self, input_: torch.Tensor) -> torch.Tensor:
+        if self.device_communicator is None:
+            raise ValueError("No device communicator found")
         return self.device_communicator.all_reduce(input_)
 
     def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
@@ -384,12 +387,16 @@ class GroupCoordinator:
 
     def _all_gather_out_place(self, input_: torch.Tensor,
                               dim: int) -> torch.Tensor:
+        if self.device_communicator is None:
+            raise ValueError("No device communicator found")
         return self.device_communicator.all_gather(input_, dim)
 
     def all_gatherv(self,
                     input_: Union[torch.Tensor, list[torch.Tensor]],
                     dim: int = 0,
                     sizes: Optional[list[int]] = None):
+        if self.device_communicator is None:
+            raise ValueError("No device communicator found")
         return self.device_communicator.all_gatherv(input_, dim, sizes)
 
     def reduce_scatter(self,
@@ -414,10 +421,14 @@ class GroupCoordinator:
                         input_: torch.Tensor,
                         dim: int = -1,
                         sizes: Optional[list[int]] = None) -> torch.Tensor:
+        if self.device_communicator is None:
+            raise ValueError("No device communicator found")
         return self.device_communicator.reduce_scatterv(input_, dim, sizes)
 
     def _reduce_scatter_out_place(self, input_: torch.Tensor,
                                   dim: int) -> torch.Tensor:
+        if self.device_communicator is None:
+            raise ValueError("No device communicator found")
         return self.device_communicator.reduce_scatter(input_, dim)
 
     def gather(self,
@@ -433,6 +444,8 @@ class GroupCoordinator:
         # Bypass the function if we are using only 1 GPU.
         if world_size == 1:
             return input_
+        if self.device_communicator is None:
+            raise ValueError("No device communicator found")
         return self.device_communicator.gather(input_, dst, dim)
 
     def broadcast(self, input_: torch.Tensor, src: int = 0):
@@ -667,6 +680,8 @@ class GroupCoordinator:
         assert dst < self.world_size, f"Invalid dst rank ({dst})"
 
         if self.use_cpu_custom_send_recv:
+            if self.device_communicator is None:
+                raise ValueError("No device communicator found")
             self.device_communicator.send_tensor_dict(  # type: ignore
                 tensor_dict, dst)
             return None
@@ -727,6 +742,8 @@ class GroupCoordinator:
         assert src < self.world_size, f"Invalid src rank ({src})"
 
         if self.use_cpu_custom_send_recv:
+            if self.device_communicator is None:
+                raise ValueError("No device communicator found")
             return self.device_communicator.recv_tensor_dict(  # type: ignore
                 src)
 
@@ -784,6 +801,8 @@ class GroupCoordinator:
     def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None:
         """Sends a tensor to the destination rank in a blocking way"""
         """NOTE: `dst` is the local rank of the destination rank."""
+        if self.device_communicator is None:
+            raise ValueError("No device communicator found")
         self.device_communicator.send(tensor, dst)
 
     def recv(self,
@@ -792,6 +811,8 @@ class GroupCoordinator:
              src: Optional[int] = None) -> torch.Tensor:
         """Receives a tensor from the source rank."""
         """NOTE: `src` is the local rank of the source rank."""
+        if self.device_communicator is None:
+            raise ValueError("No device communicator found")
         return self.device_communicator.recv(size, dtype, src)
 
     def destroy(self):