diff --git a/vllm/distributed/afd_transfer/afd_connector/p2p_connector.py b/vllm/distributed/afd_transfer/afd_connector/p2p_connector.py index 58be99bf117ea..0605facbfaffb 100644 --- a/vllm/distributed/afd_transfer/afd_connector/p2p_connector.py +++ b/vllm/distributed/afd_transfer/afd_connector/p2p_connector.py @@ -139,9 +139,11 @@ class P2PAFDConnector(AFDConnectorBase): for idx in range(num_of_stages): if idx == 0: tensor_metadata_list[0] = tensor_metadata + logger.info(f"build tensor metadata: stage_{idx=}, size={tensor_metadata.size}") else: new_size = list(tensor_metadata.size) new_size[0] = connector_metadata.afd_tokens_lens[idx] + logger.info(f"build tensor metadata: stage_{idx=}, {new_size=}, {connector_metadata.afd_tokens_lens=}") tensor_metadata_list[idx] = TensorMetadata( tensor_metadata.device, tensor_metadata.dtype, @@ -165,6 +167,7 @@ class P2PAFDConnector(AFDConnectorBase): ) metadata_tuple = (metadata, tensor_metadata) process_group.send_object(metadata_tuple, dst=dst) + logger.info(f"_send_metadata called build tensor metadata") self._tensor_metadata_list = self._build_tensor_metadata_list( tensor_metadata, metadata ) @@ -177,6 +180,7 @@ class P2PAFDConnector(AFDConnectorBase): (self._current_afd_connector_metadata, tensor_metadata) = ( process_group.recv_object(src=src) ) + logger.info(f"_recv_metadata called build tensor metadata") self._tensor_metadata_list = self._build_tensor_metadata_list( tensor_metadata, self._current_afd_connector_metadata ) @@ -225,6 +229,7 @@ class P2PAFDConnector(AFDConnectorBase): dtype=tensor_metadata.dtype, device=tensor_metadata.device, ) + # logger.info(f"{__file__}: p2p recv hidden states: {hidden_states.shape=}, {tensor_metadata.size=}") torch.distributed.recv( hidden_states, src=process_group.ranks[src], @@ -262,6 +267,7 @@ class P2PAFDConnector(AFDConnectorBase): self.recv_ffn_output_counter % self._current_afd_connector_metadata.num_of_stages ) + logger.info(f"{stage_idx=}") hidden_states, work_list = self._recv_hidden_states( src, self.e2a_group,