[Refactor] p2p connector recv_ffn_output

This commit is contained in:
i-yuanyukun 2025-12-24 16:12:35 +08:00
parent 93c656e09b
commit 9f9a583f04
3 changed files with 12 additions and 43 deletions

View File

@ -225,7 +225,7 @@ class P2PAFDConnector(AFDConnectorBase):
src: int, src: int,
process_group: GroupCoordinator, process_group: GroupCoordinator,
tensor_metadata: TensorMetadata, tensor_metadata: TensorMetadata,
) -> tuple[torch.Tensor, list]: ) -> torch.Tensor:
if not torch.distributed.is_initialized() or process_group.world_size == 1: if not torch.distributed.is_initialized() or process_group.world_size == 1:
return {}, [] return {}, []
assert src < process_group.world_size, f"Invalid src rank ({src})" assert src < process_group.world_size, f"Invalid src rank ({src})"
@ -240,7 +240,7 @@ class P2PAFDConnector(AFDConnectorBase):
src=process_group.ranks[src], src=process_group.ranks[src],
group=process_group.device_group, group=process_group.device_group,
) )
return hidden_states, [] return hidden_states
# ------------------------------------------------------------------------- # -------------------------------------------------------------------------
# attn -> ffn # attn -> ffn
@ -262,7 +262,7 @@ class P2PAFDConnector(AFDConnectorBase):
except Exception as e: except Exception as e:
raise RuntimeError(f"Communication error: {e}") raise RuntimeError(f"Communication error: {e}")
def recv_ffn_output(self) -> tuple[torch.Tensor, AFDConnectorMetadata]: def recv_ffn_output(self) -> torch.Tensor:
""" """
Called by the ATTN side to receive MOE output intermediate tensors, Called by the ATTN side to receive MOE output intermediate tensors,
possibly dispatching from the receiver to other GPUs. possibly dispatching from the receiver to other GPUs.
@ -272,16 +272,15 @@ class P2PAFDConnector(AFDConnectorBase):
self.recv_ffn_output_counter self.recv_ffn_output_counter
% self._current_afd_connector_metadata.num_of_stages % self._current_afd_connector_metadata.num_of_stages
) )
hidden_states, work_list = self._recv_hidden_states( hidden_states = self._recv_hidden_states(
src, src,
self.e2a_group, self.e2a_group,
self._tensor_metadata_list[stage_idx], self._tensor_metadata_list[stage_idx],
) )
self._current_afd_connector_metadata.recv_handle_list = work_list
self.recv_ffn_output_counter = ( self.recv_ffn_output_counter = (
self.recv_ffn_output_counter + 1 self.recv_ffn_output_counter + 1
) % self._current_afd_connector_metadata.num_of_stages ) % self._current_afd_connector_metadata.num_of_stages
return hidden_states, self._current_afd_connector_metadata return hidden_states
# ------------------------------------------------------------------------- # -------------------------------------------------------------------------
# ffn -> attn # ffn -> attn
@ -328,12 +327,11 @@ class P2PAFDConnector(AFDConnectorBase):
self.recv_attn_output_counter self.recv_attn_output_counter
// self._current_afd_connector_metadata.num_of_stages // self._current_afd_connector_metadata.num_of_stages
) )
hidden_states, work_list = self._recv_hidden_states( hidden_states = self._recv_hidden_states(
src, src,
self.a2e_group, self.a2e_group,
self._tensor_metadata_list[stage_idx], self._tensor_metadata_list[stage_idx],
) )
self._current_afd_connector_metadata.recv_handle_list = work_list
self._current_afd_connector_metadata.layer_idx = layer_idx self._current_afd_connector_metadata.layer_idx = layer_idx
self._current_afd_connector_metadata.stage_idx = stage_idx self._current_afd_connector_metadata.stage_idx = stage_idx
return hidden_states, self._current_afd_connector_metadata return hidden_states, self._current_afd_connector_metadata

View File

@ -1347,19 +1347,13 @@ class DeepseekV2Model(nn.Module):
afd_metadata: AFDMetadata, afd_metadata: AFDMetadata,
llama_4_scaling: torch.Tensor | None = None, llama_4_scaling: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
recv_handle = None
for layer in islice(self.layers, self.start_layer, self.end_layer): for layer in islice(self.layers, self.start_layer, self.end_layer):
afd_connector = afd_metadata.afd_connector afd_connector = afd_metadata.afd_connector
afd_metadata.afd_stage_idx = dbo_current_ubatch_id() afd_metadata.afd_stage_idx = dbo_current_ubatch_id()
if layer.layer_idx > 0: if layer.layer_idx > 0:
hidden_states, recv_metadata = afd_connector.recv_ffn_output() hidden_states = afd_connector.recv_ffn_output()
if recv_metadata.recv_handle_list is not None:
recv_handle = recv_metadata.recv_handle_list
if recv_handle is not None:
for work in recv_handle:
work.wait()
current_hidden, residual = layer( current_hidden, residual = layer(
positions, hidden_states, residual, llama_4_scaling positions, hidden_states, residual, llama_4_scaling
) )
@ -1377,12 +1371,7 @@ class DeepseekV2Model(nn.Module):
if dbo_enabled(): if dbo_enabled():
dbo_yield() dbo_yield()
hidden_states, recv_metadata = afd_connector.recv_ffn_output() hidden_states = afd_connector.recv_ffn_output()
if recv_metadata.recv_handle_list is not None:
recv_handle = recv_metadata.recv_handle_list
if recv_handle is not None:
for work in recv_handle:
work.wait()
return hidden_states, residual return hidden_states, residual
@ -1395,7 +1384,6 @@ class DeepseekV2Model(nn.Module):
llama_4_scaling: torch.Tensor | None = None, llama_4_scaling: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
forward_conext = get_forward_context() forward_conext = get_forward_context()
recv_handle = None
ubatch_hidden_states = [] ubatch_hidden_states = []
ubatch_residual = [] ubatch_residual = []
@ -1421,16 +1409,10 @@ class DeepseekV2Model(nn.Module):
residual = ubatch_residual[stage_i] residual = ubatch_residual[stage_i]
if layer.layer_idx > 0: if layer.layer_idx > 0:
hidden_states, recv_metadata = afd_connector.recv_ffn_output() hidden_states = afd_connector.recv_ffn_output()
if recv_metadata.recv_handle_list is not None:
recv_handle = recv_metadata.recv_handle_list
else: else:
hidden_states = ubatch_hidden_states[stage_i] hidden_states = ubatch_hidden_states[stage_i]
if recv_handle is not None:
for work in recv_handle:
work.wait()
current_positions = afd_metadata.positions_list[stage_i] current_positions = afd_metadata.positions_list[stage_i]
hidden_states, residual = layer( hidden_states, residual = layer(
current_positions, hidden_states, residual, llama_4_scaling current_positions, hidden_states, residual, llama_4_scaling
@ -1452,9 +1434,7 @@ class DeepseekV2Model(nn.Module):
# Recv last layer FFN output. # Recv last layer FFN output.
for stage_i in range(afd_metadata.num_of_stages): for stage_i in range(afd_metadata.num_of_stages):
ubatch_hidden_states[stage_i], recv_metadata = ( ubatch_hidden_states[stage_i] = afd_connector.recv_ffn_output()
afd_connector.recv_ffn_output()
)
# Re-assemble the batch # Re-assemble the batch
hidden_states = torch.cat(ubatch_hidden_states, dim=0) hidden_states = torch.cat(ubatch_hidden_states, dim=0)

View File

@ -385,7 +385,6 @@ class Step3TextModel(nn.Module):
afd_metadata: AFDMetadata, afd_metadata: AFDMetadata,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
forward_conext = get_forward_context() forward_conext = get_forward_context()
recv_handle = None
ubatch_hidden_states = [] ubatch_hidden_states = []
ubatch_residual = [] ubatch_residual = []
@ -409,16 +408,10 @@ class Step3TextModel(nn.Module):
residual = ubatch_residual[stage_i] residual = ubatch_residual[stage_i]
if layer.layer_idx > 0: if layer.layer_idx > 0:
hidden_states, recv_metadata = afd_connector.recv_ffn_output() hidden_states = afd_connector.recv_ffn_output()
if recv_metadata.recv_handle_list is not None:
recv_handle = recv_metadata.recv_handle_list
else: else:
hidden_states = ubatch_hidden_states[stage_i] hidden_states = ubatch_hidden_states[stage_i]
if recv_handle is not None:
for work in recv_handle:
work.wait()
current_positions = afd_metadata.positions_list[stage_i] current_positions = afd_metadata.positions_list[stage_i]
hidden_states, residual = layer( hidden_states, residual = layer(
current_positions, hidden_states, residual current_positions, hidden_states, residual
@ -439,9 +432,7 @@ class Step3TextModel(nn.Module):
# Recv last layer FFN output. # Recv last layer FFN output.
for stage_i in range(afd_metadata.num_of_stages): for stage_i in range(afd_metadata.num_of_stages):
ubatch_hidden_states[stage_i], recv_metadata = ( ubatch_hidden_states[stage_i] = afd_connector.recv_ffn_output()
afd_connector.recv_ffn_output()
)
# Re-assemble the batch # Re-assemble the batch
hidden_states = torch.cat(ubatch_hidden_states, dim=0) hidden_states = torch.cat(ubatch_hidden_states, dim=0)