diff --git a/vllm/distributed/afd_transfer/afd_connector/p2p_connector.py b/vllm/distributed/afd_transfer/afd_connector/p2p_connector.py index f85679400d1c6..a48c6e0e752d1 100644 --- a/vllm/distributed/afd_transfer/afd_connector/p2p_connector.py +++ b/vllm/distributed/afd_transfer/afd_connector/p2p_connector.py @@ -225,7 +225,7 @@ class P2PAFDConnector(AFDConnectorBase): src: int, process_group: GroupCoordinator, tensor_metadata: TensorMetadata, - ) -> tuple[torch.Tensor, list]: + ) -> torch.Tensor: if not torch.distributed.is_initialized() or process_group.world_size == 1: - return {}, [] + return {} assert src < process_group.world_size, f"Invalid src rank ({src})" @@ -240,7 +240,7 @@ class P2PAFDConnector(AFDConnectorBase): src=process_group.ranks[src], group=process_group.device_group, ) - return hidden_states, [] + return hidden_states # ------------------------------------------------------------------------- # attn -> ffn @@ -262,7 +262,7 @@ class P2PAFDConnector(AFDConnectorBase): except Exception as e: raise RuntimeError(f"Communication error: {e}") - def recv_ffn_output(self) -> tuple[torch.Tensor, AFDConnectorMetadata]: + def recv_ffn_output(self) -> torch.Tensor: """ Called by the ATTN side to receive MOE output intermediate tensors, possibly dispatching from the receiver to other GPUs. 
@@ -272,16 +272,15 @@ class P2PAFDConnector(AFDConnectorBase): self.recv_ffn_output_counter % self._current_afd_connector_metadata.num_of_stages ) - hidden_states, work_list = self._recv_hidden_states( + hidden_states = self._recv_hidden_states( src, self.e2a_group, self._tensor_metadata_list[stage_idx], ) - self._current_afd_connector_metadata.recv_handle_list = work_list self.recv_ffn_output_counter = ( self.recv_ffn_output_counter + 1 ) % self._current_afd_connector_metadata.num_of_stages - return hidden_states, self._current_afd_connector_metadata + return hidden_states # ------------------------------------------------------------------------- # ffn -> attn @@ -328,12 +327,11 @@ class P2PAFDConnector(AFDConnectorBase): self.recv_attn_output_counter // self._current_afd_connector_metadata.num_of_stages ) - hidden_states, work_list = self._recv_hidden_states( + hidden_states = self._recv_hidden_states( src, self.a2e_group, self._tensor_metadata_list[stage_idx], ) - self._current_afd_connector_metadata.recv_handle_list = work_list self._current_afd_connector_metadata.layer_idx = layer_idx self._current_afd_connector_metadata.stage_idx = stage_idx return hidden_states, self._current_afd_connector_metadata diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 6b7842b00f54a..9b9cf7a59126f 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -1347,19 +1347,13 @@ class DeepseekV2Model(nn.Module): afd_metadata: AFDMetadata, llama_4_scaling: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: - recv_handle = None for layer in islice(self.layers, self.start_layer, self.end_layer): afd_connector = afd_metadata.afd_connector afd_metadata.afd_stage_idx = dbo_current_ubatch_id() if layer.layer_idx > 0: - hidden_states, recv_metadata = afd_connector.recv_ffn_output() - if recv_metadata.recv_handle_list is not None: - recv_handle = 
recv_metadata.recv_handle_list + hidden_states = afd_connector.recv_ffn_output() - if recv_handle is not None: - for work in recv_handle: - work.wait() current_hidden, residual = layer( positions, hidden_states, residual, llama_4_scaling ) @@ -1377,12 +1371,7 @@ class DeepseekV2Model(nn.Module): if dbo_enabled(): dbo_yield() - hidden_states, recv_metadata = afd_connector.recv_ffn_output() - if recv_metadata.recv_handle_list is not None: - recv_handle = recv_metadata.recv_handle_list - if recv_handle is not None: - for work in recv_handle: - work.wait() + hidden_states = afd_connector.recv_ffn_output() return hidden_states, residual @@ -1395,7 +1384,6 @@ class DeepseekV2Model(nn.Module): llama_4_scaling: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: forward_conext = get_forward_context() - recv_handle = None ubatch_hidden_states = [] ubatch_residual = [] @@ -1421,16 +1409,10 @@ class DeepseekV2Model(nn.Module): residual = ubatch_residual[stage_i] if layer.layer_idx > 0: - hidden_states, recv_metadata = afd_connector.recv_ffn_output() - if recv_metadata.recv_handle_list is not None: - recv_handle = recv_metadata.recv_handle_list + hidden_states = afd_connector.recv_ffn_output() else: hidden_states = ubatch_hidden_states[stage_i] - if recv_handle is not None: - for work in recv_handle: - work.wait() - current_positions = afd_metadata.positions_list[stage_i] hidden_states, residual = layer( current_positions, hidden_states, residual, llama_4_scaling @@ -1452,9 +1434,7 @@ class DeepseekV2Model(nn.Module): # Recv last layer FFN output. 
for stage_i in range(afd_metadata.num_of_stages): - ubatch_hidden_states[stage_i], recv_metadata = ( - afd_connector.recv_ffn_output() - ) + ubatch_hidden_states[stage_i] = afd_connector.recv_ffn_output() # Re-assemble the batch hidden_states = torch.cat(ubatch_hidden_states, dim=0) diff --git a/vllm/model_executor/models/step3_text.py b/vllm/model_executor/models/step3_text.py index eb723582a0ccb..6e87aeab49b3e 100644 --- a/vllm/model_executor/models/step3_text.py +++ b/vllm/model_executor/models/step3_text.py @@ -385,7 +385,6 @@ class Step3TextModel(nn.Module): afd_metadata: AFDMetadata, ) -> tuple[torch.Tensor, torch.Tensor]: forward_conext = get_forward_context() - recv_handle = None ubatch_hidden_states = [] ubatch_residual = [] @@ -409,16 +408,10 @@ class Step3TextModel(nn.Module): residual = ubatch_residual[stage_i] if layer.layer_idx > 0: - hidden_states, recv_metadata = afd_connector.recv_ffn_output() - if recv_metadata.recv_handle_list is not None: - recv_handle = recv_metadata.recv_handle_list + hidden_states = afd_connector.recv_ffn_output() else: hidden_states = ubatch_hidden_states[stage_i] - if recv_handle is not None: - for work in recv_handle: - work.wait() - current_positions = afd_metadata.positions_list[stage_i] hidden_states, residual = layer( current_positions, hidden_states, residual @@ -439,9 +432,7 @@ class Step3TextModel(nn.Module): # Recv last layer FFN output. for stage_i in range(afd_metadata.num_of_stages): - ubatch_hidden_states[stage_i], recv_metadata = ( - afd_connector.recv_ffn_output() - ) + ubatch_hidden_states[stage_i] = afd_connector.recv_ffn_output() # Re-assemble the batch hidden_states = torch.cat(ubatch_hidden_states, dim=0)