diff --git a/vllm/distributed/device_communicators/ray_communicator.py b/vllm/distributed/device_communicators/ray_communicator.py
index e5ba297ebcc1..46cc1c2f52d6 100644
--- a/vllm/distributed/device_communicators/ray_communicator.py
+++ b/vllm/distributed/device_communicators/ray_communicator.py
@@ -70,6 +70,7 @@ class RayPPCommunicator(Communicator):
         assert ray.get_gpu_ids(), "RayPPCommunicator has no GPUs assigned"
 
         self._comm = get_pp_group().device_communicator
+        assert self._comm is not None
 
         # Since we wrap around the vLLM _PP communicator, we use
         # the rank from the vLLM communicator, and ignore the rank
diff --git a/vllm/distributed/eplb/eplb_state.py b/vllm/distributed/eplb/eplb_state.py
index af6462084968..f64b516b0d04 100644
--- a/vllm/distributed/eplb/eplb_state.py
+++ b/vllm/distributed/eplb/eplb_state.py
@@ -251,6 +251,7 @@ class EplbState:
 
         if global_expert_load is not None:
             ep_group = get_ep_group().device_group
+            assert ep_group is not None
             assert global_expert_load.shape == (model.num_moe_layers,
                                                 model.num_logical_experts)
             assert global_expert_load.dtype == torch.int64
@@ -357,6 +358,7 @@ class EplbState:
 
         # Collect load metrics from all ranks
         ep_group = get_ep_group().device_group
+        assert ep_group is not None
         num_tokens_list = [
            torch.empty_like(num_tokens) for _ in range(ep_group.size())
         ]
@@ -412,6 +414,7 @@ class EplbState:
         """
 
         ep_group = get_ep_group().device_group
+        assert ep_group is not None
         ep_rank = ep_group.rank()
 
         time_start = None
diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py
index 470c1355d2a9..6c25cdcfb7b8 100644
--- a/vllm/distributed/parallel_state.py
+++ b/vllm/distributed/parallel_state.py
@@ -196,10 +196,11 @@ class GroupCoordinator:
     # 3 | 1 | 3 | 1 | 3
     local_rank: int  # local rank used to assign devices
     rank_in_group: int  # rank inside the group
-    cpu_group: ProcessGroup  # group for CPU communication
-    device_group: ProcessGroup  # group for device communication
+    cpu_group: Optional[ProcessGroup]  # group for CPU communication
+    device_group: Optional[ProcessGroup]  # group for device communication
     use_device_communicator: bool  # whether to use device communicator
-    device_communicator: DeviceCommunicatorBase  # device communicator
+    device_communicator: Optional[
+        DeviceCommunicatorBase]  # device communicator
     mq_broadcaster: Optional[Any]  # shared memory broadcaster
 
     def __init__(
@@ -250,7 +251,7 @@ class GroupCoordinator:
 
         self.use_device_communicator = use_device_communicator
 
-        self.device_communicator: DeviceCommunicatorBase = None  # type: ignore
+        self.device_communicator = None
         if use_device_communicator and self.world_size > 1:
             device_comm_cls = resolve_obj_by_qualname(
                 current_platform.get_device_communicator_cls())
@@ -364,6 +365,8 @@ class GroupCoordinator:
             return self._all_reduce_out_place(input_)
 
     def _all_reduce_out_place(self, input_: torch.Tensor) -> torch.Tensor:
+        if self.device_communicator is None:
+            raise ValueError("No device communicator found")
         return self.device_communicator.all_reduce(input_)
 
     def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
@@ -384,12 +387,16 @@ class GroupCoordinator:
 
     def _all_gather_out_place(self, input_: torch.Tensor,
                               dim: int) -> torch.Tensor:
+        if self.device_communicator is None:
+            raise ValueError("No device communicator found")
         return self.device_communicator.all_gather(input_, dim)
 
     def all_gatherv(self,
                     input_: Union[torch.Tensor, list[torch.Tensor]],
                     dim: int = 0,
                     sizes: Optional[list[int]] = None):
+        if self.device_communicator is None:
+            raise ValueError("No device communicator found")
         return self.device_communicator.all_gatherv(input_, dim, sizes)
 
     def reduce_scatter(self,
@@ -414,10 +421,14 @@ class GroupCoordinator:
                         input_: torch.Tensor,
                         dim: int = -1,
                         sizes: Optional[list[int]] = None) -> torch.Tensor:
+        if self.device_communicator is None:
+            raise ValueError("No device communicator found")
         return self.device_communicator.reduce_scatterv(input_, dim, sizes)
 
     def _reduce_scatter_out_place(self, input_: torch.Tensor,
                                   dim: int) -> torch.Tensor:
+        if self.device_communicator is None:
+            raise ValueError("No device communicator found")
         return self.device_communicator.reduce_scatter(input_, dim)
 
     def gather(self,
@@ -433,6 +444,8 @@ class GroupCoordinator:
         # Bypass the function if we are using only 1 GPU.
         if world_size == 1:
             return input_
+        if self.device_communicator is None:
+            raise ValueError("No device communicator found")
         return self.device_communicator.gather(input_, dst, dim)
 
     def broadcast(self, input_: torch.Tensor, src: int = 0):
@@ -667,6 +680,8 @@ class GroupCoordinator:
         assert dst < self.world_size, f"Invalid dst rank ({dst})"
 
         if self.use_cpu_custom_send_recv:
+            if self.device_communicator is None:
+                raise ValueError("No device communicator found")
             self.device_communicator.send_tensor_dict(  # type: ignore
                 tensor_dict, dst)
             return None
@@ -727,6 +742,8 @@ class GroupCoordinator:
         assert src < self.world_size, f"Invalid src rank ({src})"
 
         if self.use_cpu_custom_send_recv:
+            if self.device_communicator is None:
+                raise ValueError("No device communicator found")
             return self.device_communicator.recv_tensor_dict(  # type: ignore
                 src)
 
@@ -784,6 +801,8 @@ class GroupCoordinator:
     def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None:
         """Sends a tensor to the destination rank in a blocking way"""
         """NOTE: `dst` is the local rank of the destination rank."""
+        if self.device_communicator is None:
+            raise ValueError("No device communicator found")
         self.device_communicator.send(tensor, dst)
 
     def recv(self,
@@ -792,6 +811,8 @@ class GroupCoordinator:
              src: Optional[int] = None) -> torch.Tensor:
         """Receives a tensor from the source rank."""
         """NOTE: `src` is the local rank of the source rank."""
+        if self.device_communicator is None:
+            raise ValueError("No device communicator found")
         return self.device_communicator.recv(size, dtype, src)
 
     def destroy(self):