Merge pull request #21 from jiangkuaixue123/afd-refactor-p2p-connector

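[Refactor] AFD P2P connector: recv_ffn_output now returns only the hidden-states tensor (recv handle list and the callers' work.wait() loops removed)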
This commit is contained in:
jiangkuaixue123 2025-12-24 16:44:55 +08:00 committed by GitHub
commit e7254d8994
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 12 additions and 43 deletions

View File

@ -225,7 +225,7 @@ class P2PAFDConnector(AFDConnectorBase):
src: int, src: int,
process_group: GroupCoordinator, process_group: GroupCoordinator,
tensor_metadata: TensorMetadata, tensor_metadata: TensorMetadata,
) -> tuple[torch.Tensor, list]: ) -> torch.Tensor:
if not torch.distributed.is_initialized() or process_group.world_size == 1: if not torch.distributed.is_initialized() or process_group.world_size == 1:
return {}, [] return {}, []
assert src < process_group.world_size, f"Invalid src rank ({src})" assert src < process_group.world_size, f"Invalid src rank ({src})"
@ -240,7 +240,7 @@ class P2PAFDConnector(AFDConnectorBase):
src=process_group.ranks[src], src=process_group.ranks[src],
group=process_group.device_group, group=process_group.device_group,
) )
return hidden_states, [] return hidden_states
# ------------------------------------------------------------------------- # -------------------------------------------------------------------------
# attn -> ffn # attn -> ffn
@ -262,7 +262,7 @@ class P2PAFDConnector(AFDConnectorBase):
except Exception as e: except Exception as e:
raise RuntimeError(f"Communication error: {e}") raise RuntimeError(f"Communication error: {e}")
def recv_ffn_output(self) -> tuple[torch.Tensor, AFDConnectorMetadata]: def recv_ffn_output(self) -> torch.Tensor:
""" """
Called by the ATTN side to receive MOE output intermediate tensors, Called by the ATTN side to receive MOE output intermediate tensors,
possibly dispatching from the receiver to other GPUs. possibly dispatching from the receiver to other GPUs.
@ -272,16 +272,15 @@ class P2PAFDConnector(AFDConnectorBase):
self.recv_ffn_output_counter self.recv_ffn_output_counter
% self._current_afd_connector_metadata.num_of_stages % self._current_afd_connector_metadata.num_of_stages
) )
hidden_states, work_list = self._recv_hidden_states( hidden_states = self._recv_hidden_states(
src, src,
self.e2a_group, self.e2a_group,
self._tensor_metadata_list[stage_idx], self._tensor_metadata_list[stage_idx],
) )
self._current_afd_connector_metadata.recv_handle_list = work_list
self.recv_ffn_output_counter = ( self.recv_ffn_output_counter = (
self.recv_ffn_output_counter + 1 self.recv_ffn_output_counter + 1
) % self._current_afd_connector_metadata.num_of_stages ) % self._current_afd_connector_metadata.num_of_stages
return hidden_states, self._current_afd_connector_metadata return hidden_states
# ------------------------------------------------------------------------- # -------------------------------------------------------------------------
# ffn -> attn # ffn -> attn
@ -328,12 +327,11 @@ class P2PAFDConnector(AFDConnectorBase):
self.recv_attn_output_counter self.recv_attn_output_counter
// self._current_afd_connector_metadata.num_of_stages // self._current_afd_connector_metadata.num_of_stages
) )
hidden_states, work_list = self._recv_hidden_states( hidden_states = self._recv_hidden_states(
src, src,
self.a2e_group, self.a2e_group,
self._tensor_metadata_list[stage_idx], self._tensor_metadata_list[stage_idx],
) )
self._current_afd_connector_metadata.recv_handle_list = work_list
self._current_afd_connector_metadata.layer_idx = layer_idx self._current_afd_connector_metadata.layer_idx = layer_idx
self._current_afd_connector_metadata.stage_idx = stage_idx self._current_afd_connector_metadata.stage_idx = stage_idx
return hidden_states, self._current_afd_connector_metadata return hidden_states, self._current_afd_connector_metadata

View File

@ -1347,19 +1347,13 @@ class DeepseekV2Model(nn.Module):
afd_metadata: AFDMetadata, afd_metadata: AFDMetadata,
llama_4_scaling: torch.Tensor | None = None, llama_4_scaling: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
recv_handle = None
for layer in islice(self.layers, self.start_layer, self.end_layer): for layer in islice(self.layers, self.start_layer, self.end_layer):
afd_connector = afd_metadata.afd_connector afd_connector = afd_metadata.afd_connector
afd_metadata.afd_stage_idx = dbo_current_ubatch_id() afd_metadata.afd_stage_idx = dbo_current_ubatch_id()
if layer.layer_idx > 0: if layer.layer_idx > 0:
hidden_states, recv_metadata = afd_connector.recv_ffn_output() hidden_states = afd_connector.recv_ffn_output()
if recv_metadata.recv_handle_list is not None:
recv_handle = recv_metadata.recv_handle_list
if recv_handle is not None:
for work in recv_handle:
work.wait()
current_hidden, residual = layer( current_hidden, residual = layer(
positions, hidden_states, residual, llama_4_scaling positions, hidden_states, residual, llama_4_scaling
) )
@ -1377,12 +1371,7 @@ class DeepseekV2Model(nn.Module):
if dbo_enabled(): if dbo_enabled():
dbo_yield() dbo_yield()
hidden_states, recv_metadata = afd_connector.recv_ffn_output() hidden_states = afd_connector.recv_ffn_output()
if recv_metadata.recv_handle_list is not None:
recv_handle = recv_metadata.recv_handle_list
if recv_handle is not None:
for work in recv_handle:
work.wait()
return hidden_states, residual return hidden_states, residual
@ -1395,7 +1384,6 @@ class DeepseekV2Model(nn.Module):
llama_4_scaling: torch.Tensor | None = None, llama_4_scaling: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
forward_conext = get_forward_context() forward_conext = get_forward_context()
recv_handle = None
ubatch_hidden_states = [] ubatch_hidden_states = []
ubatch_residual = [] ubatch_residual = []
@ -1421,16 +1409,10 @@ class DeepseekV2Model(nn.Module):
residual = ubatch_residual[stage_i] residual = ubatch_residual[stage_i]
if layer.layer_idx > 0: if layer.layer_idx > 0:
hidden_states, recv_metadata = afd_connector.recv_ffn_output() hidden_states = afd_connector.recv_ffn_output()
if recv_metadata.recv_handle_list is not None:
recv_handle = recv_metadata.recv_handle_list
else: else:
hidden_states = ubatch_hidden_states[stage_i] hidden_states = ubatch_hidden_states[stage_i]
if recv_handle is not None:
for work in recv_handle:
work.wait()
current_positions = afd_metadata.positions_list[stage_i] current_positions = afd_metadata.positions_list[stage_i]
hidden_states, residual = layer( hidden_states, residual = layer(
current_positions, hidden_states, residual, llama_4_scaling current_positions, hidden_states, residual, llama_4_scaling
@ -1452,9 +1434,7 @@ class DeepseekV2Model(nn.Module):
# Recv last layer FFN output. # Recv last layer FFN output.
for stage_i in range(afd_metadata.num_of_stages): for stage_i in range(afd_metadata.num_of_stages):
ubatch_hidden_states[stage_i], recv_metadata = ( ubatch_hidden_states[stage_i] = afd_connector.recv_ffn_output()
afd_connector.recv_ffn_output()
)
# Re-assemble the batch # Re-assemble the batch
hidden_states = torch.cat(ubatch_hidden_states, dim=0) hidden_states = torch.cat(ubatch_hidden_states, dim=0)

View File

@ -385,7 +385,6 @@ class Step3TextModel(nn.Module):
afd_metadata: AFDMetadata, afd_metadata: AFDMetadata,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
forward_conext = get_forward_context() forward_conext = get_forward_context()
recv_handle = None
ubatch_hidden_states = [] ubatch_hidden_states = []
ubatch_residual = [] ubatch_residual = []
@ -409,16 +408,10 @@ class Step3TextModel(nn.Module):
residual = ubatch_residual[stage_i] residual = ubatch_residual[stage_i]
if layer.layer_idx > 0: if layer.layer_idx > 0:
hidden_states, recv_metadata = afd_connector.recv_ffn_output() hidden_states = afd_connector.recv_ffn_output()
if recv_metadata.recv_handle_list is not None:
recv_handle = recv_metadata.recv_handle_list
else: else:
hidden_states = ubatch_hidden_states[stage_i] hidden_states = ubatch_hidden_states[stage_i]
if recv_handle is not None:
for work in recv_handle:
work.wait()
current_positions = afd_metadata.positions_list[stage_i] current_positions = afd_metadata.positions_list[stage_i]
hidden_states, residual = layer( hidden_states, residual = layer(
current_positions, hidden_states, residual current_positions, hidden_states, residual
@ -439,9 +432,7 @@ class Step3TextModel(nn.Module):
# Recv last layer FFN output. # Recv last layer FFN output.
for stage_i in range(afd_metadata.num_of_stages): for stage_i in range(afd_metadata.num_of_stages):
ubatch_hidden_states[stage_i], recv_metadata = ( ubatch_hidden_states[stage_i] = afd_connector.recv_ffn_output()
afd_connector.recv_ffn_output()
)
# Re-assemble the batch # Re-assemble the batch
hidden_states = torch.cat(ubatch_hidden_states, dim=0) hidden_states = torch.cat(ubatch_hidden_states, dim=0)