mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-21 20:30:14 +08:00
Merge pull request #21 from jiangkuaixue123/afd-refactor-p2p-connector
[Refactor] AFD p2pconnector recv_ffn_output
This commit is contained in:
commit
e7254d8994
@ -225,7 +225,7 @@ class P2PAFDConnector(AFDConnectorBase):
|
|||||||
src: int,
|
src: int,
|
||||||
process_group: GroupCoordinator,
|
process_group: GroupCoordinator,
|
||||||
tensor_metadata: TensorMetadata,
|
tensor_metadata: TensorMetadata,
|
||||||
) -> tuple[torch.Tensor, list]:
|
) -> torch.Tensor:
|
||||||
if not torch.distributed.is_initialized() or process_group.world_size == 1:
|
if not torch.distributed.is_initialized() or process_group.world_size == 1:
|
||||||
return {}, []
|
return {}, []
|
||||||
assert src < process_group.world_size, f"Invalid src rank ({src})"
|
assert src < process_group.world_size, f"Invalid src rank ({src})"
|
||||||
@ -240,7 +240,7 @@ class P2PAFDConnector(AFDConnectorBase):
|
|||||||
src=process_group.ranks[src],
|
src=process_group.ranks[src],
|
||||||
group=process_group.device_group,
|
group=process_group.device_group,
|
||||||
)
|
)
|
||||||
return hidden_states, []
|
return hidden_states
|
||||||
|
|
||||||
# -------------------------------------------------------------------------
|
# -------------------------------------------------------------------------
|
||||||
# attn -> ffn
|
# attn -> ffn
|
||||||
@ -262,7 +262,7 @@ class P2PAFDConnector(AFDConnectorBase):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise RuntimeError(f"Communication error: {e}")
|
raise RuntimeError(f"Communication error: {e}")
|
||||||
|
|
||||||
def recv_ffn_output(self) -> tuple[torch.Tensor, AFDConnectorMetadata]:
|
def recv_ffn_output(self) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Called by the ATTN side to receive MOE output intermediate tensors,
|
Called by the ATTN side to receive MOE output intermediate tensors,
|
||||||
possibly dispatching from the receiver to other GPUs.
|
possibly dispatching from the receiver to other GPUs.
|
||||||
@ -272,16 +272,15 @@ class P2PAFDConnector(AFDConnectorBase):
|
|||||||
self.recv_ffn_output_counter
|
self.recv_ffn_output_counter
|
||||||
% self._current_afd_connector_metadata.num_of_stages
|
% self._current_afd_connector_metadata.num_of_stages
|
||||||
)
|
)
|
||||||
hidden_states, work_list = self._recv_hidden_states(
|
hidden_states = self._recv_hidden_states(
|
||||||
src,
|
src,
|
||||||
self.e2a_group,
|
self.e2a_group,
|
||||||
self._tensor_metadata_list[stage_idx],
|
self._tensor_metadata_list[stage_idx],
|
||||||
)
|
)
|
||||||
self._current_afd_connector_metadata.recv_handle_list = work_list
|
|
||||||
self.recv_ffn_output_counter = (
|
self.recv_ffn_output_counter = (
|
||||||
self.recv_ffn_output_counter + 1
|
self.recv_ffn_output_counter + 1
|
||||||
) % self._current_afd_connector_metadata.num_of_stages
|
) % self._current_afd_connector_metadata.num_of_stages
|
||||||
return hidden_states, self._current_afd_connector_metadata
|
return hidden_states
|
||||||
|
|
||||||
# -------------------------------------------------------------------------
|
# -------------------------------------------------------------------------
|
||||||
# ffn -> attn
|
# ffn -> attn
|
||||||
@ -328,12 +327,11 @@ class P2PAFDConnector(AFDConnectorBase):
|
|||||||
self.recv_attn_output_counter
|
self.recv_attn_output_counter
|
||||||
// self._current_afd_connector_metadata.num_of_stages
|
// self._current_afd_connector_metadata.num_of_stages
|
||||||
)
|
)
|
||||||
hidden_states, work_list = self._recv_hidden_states(
|
hidden_states = self._recv_hidden_states(
|
||||||
src,
|
src,
|
||||||
self.a2e_group,
|
self.a2e_group,
|
||||||
self._tensor_metadata_list[stage_idx],
|
self._tensor_metadata_list[stage_idx],
|
||||||
)
|
)
|
||||||
self._current_afd_connector_metadata.recv_handle_list = work_list
|
|
||||||
self._current_afd_connector_metadata.layer_idx = layer_idx
|
self._current_afd_connector_metadata.layer_idx = layer_idx
|
||||||
self._current_afd_connector_metadata.stage_idx = stage_idx
|
self._current_afd_connector_metadata.stage_idx = stage_idx
|
||||||
return hidden_states, self._current_afd_connector_metadata
|
return hidden_states, self._current_afd_connector_metadata
|
||||||
|
|||||||
@ -1347,19 +1347,13 @@ class DeepseekV2Model(nn.Module):
|
|||||||
afd_metadata: AFDMetadata,
|
afd_metadata: AFDMetadata,
|
||||||
llama_4_scaling: torch.Tensor | None = None,
|
llama_4_scaling: torch.Tensor | None = None,
|
||||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
recv_handle = None
|
|
||||||
for layer in islice(self.layers, self.start_layer, self.end_layer):
|
for layer in islice(self.layers, self.start_layer, self.end_layer):
|
||||||
afd_connector = afd_metadata.afd_connector
|
afd_connector = afd_metadata.afd_connector
|
||||||
afd_metadata.afd_stage_idx = dbo_current_ubatch_id()
|
afd_metadata.afd_stage_idx = dbo_current_ubatch_id()
|
||||||
|
|
||||||
if layer.layer_idx > 0:
|
if layer.layer_idx > 0:
|
||||||
hidden_states, recv_metadata = afd_connector.recv_ffn_output()
|
hidden_states = afd_connector.recv_ffn_output()
|
||||||
if recv_metadata.recv_handle_list is not None:
|
|
||||||
recv_handle = recv_metadata.recv_handle_list
|
|
||||||
|
|
||||||
if recv_handle is not None:
|
|
||||||
for work in recv_handle:
|
|
||||||
work.wait()
|
|
||||||
current_hidden, residual = layer(
|
current_hidden, residual = layer(
|
||||||
positions, hidden_states, residual, llama_4_scaling
|
positions, hidden_states, residual, llama_4_scaling
|
||||||
)
|
)
|
||||||
@ -1377,12 +1371,7 @@ class DeepseekV2Model(nn.Module):
|
|||||||
if dbo_enabled():
|
if dbo_enabled():
|
||||||
dbo_yield()
|
dbo_yield()
|
||||||
|
|
||||||
hidden_states, recv_metadata = afd_connector.recv_ffn_output()
|
hidden_states = afd_connector.recv_ffn_output()
|
||||||
if recv_metadata.recv_handle_list is not None:
|
|
||||||
recv_handle = recv_metadata.recv_handle_list
|
|
||||||
if recv_handle is not None:
|
|
||||||
for work in recv_handle:
|
|
||||||
work.wait()
|
|
||||||
|
|
||||||
return hidden_states, residual
|
return hidden_states, residual
|
||||||
|
|
||||||
@ -1395,7 +1384,6 @@ class DeepseekV2Model(nn.Module):
|
|||||||
llama_4_scaling: torch.Tensor | None = None,
|
llama_4_scaling: torch.Tensor | None = None,
|
||||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
forward_conext = get_forward_context()
|
forward_conext = get_forward_context()
|
||||||
recv_handle = None
|
|
||||||
|
|
||||||
ubatch_hidden_states = []
|
ubatch_hidden_states = []
|
||||||
ubatch_residual = []
|
ubatch_residual = []
|
||||||
@ -1421,16 +1409,10 @@ class DeepseekV2Model(nn.Module):
|
|||||||
residual = ubatch_residual[stage_i]
|
residual = ubatch_residual[stage_i]
|
||||||
|
|
||||||
if layer.layer_idx > 0:
|
if layer.layer_idx > 0:
|
||||||
hidden_states, recv_metadata = afd_connector.recv_ffn_output()
|
hidden_states = afd_connector.recv_ffn_output()
|
||||||
if recv_metadata.recv_handle_list is not None:
|
|
||||||
recv_handle = recv_metadata.recv_handle_list
|
|
||||||
else:
|
else:
|
||||||
hidden_states = ubatch_hidden_states[stage_i]
|
hidden_states = ubatch_hidden_states[stage_i]
|
||||||
|
|
||||||
if recv_handle is not None:
|
|
||||||
for work in recv_handle:
|
|
||||||
work.wait()
|
|
||||||
|
|
||||||
current_positions = afd_metadata.positions_list[stage_i]
|
current_positions = afd_metadata.positions_list[stage_i]
|
||||||
hidden_states, residual = layer(
|
hidden_states, residual = layer(
|
||||||
current_positions, hidden_states, residual, llama_4_scaling
|
current_positions, hidden_states, residual, llama_4_scaling
|
||||||
@ -1452,9 +1434,7 @@ class DeepseekV2Model(nn.Module):
|
|||||||
|
|
||||||
# Recv last layer FFN output.
|
# Recv last layer FFN output.
|
||||||
for stage_i in range(afd_metadata.num_of_stages):
|
for stage_i in range(afd_metadata.num_of_stages):
|
||||||
ubatch_hidden_states[stage_i], recv_metadata = (
|
ubatch_hidden_states[stage_i] = afd_connector.recv_ffn_output()
|
||||||
afd_connector.recv_ffn_output()
|
|
||||||
)
|
|
||||||
|
|
||||||
# Re-assemble the batch
|
# Re-assemble the batch
|
||||||
hidden_states = torch.cat(ubatch_hidden_states, dim=0)
|
hidden_states = torch.cat(ubatch_hidden_states, dim=0)
|
||||||
|
|||||||
@ -385,7 +385,6 @@ class Step3TextModel(nn.Module):
|
|||||||
afd_metadata: AFDMetadata,
|
afd_metadata: AFDMetadata,
|
||||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
forward_conext = get_forward_context()
|
forward_conext = get_forward_context()
|
||||||
recv_handle = None
|
|
||||||
|
|
||||||
ubatch_hidden_states = []
|
ubatch_hidden_states = []
|
||||||
ubatch_residual = []
|
ubatch_residual = []
|
||||||
@ -409,16 +408,10 @@ class Step3TextModel(nn.Module):
|
|||||||
residual = ubatch_residual[stage_i]
|
residual = ubatch_residual[stage_i]
|
||||||
|
|
||||||
if layer.layer_idx > 0:
|
if layer.layer_idx > 0:
|
||||||
hidden_states, recv_metadata = afd_connector.recv_ffn_output()
|
hidden_states = afd_connector.recv_ffn_output()
|
||||||
if recv_metadata.recv_handle_list is not None:
|
|
||||||
recv_handle = recv_metadata.recv_handle_list
|
|
||||||
else:
|
else:
|
||||||
hidden_states = ubatch_hidden_states[stage_i]
|
hidden_states = ubatch_hidden_states[stage_i]
|
||||||
|
|
||||||
if recv_handle is not None:
|
|
||||||
for work in recv_handle:
|
|
||||||
work.wait()
|
|
||||||
|
|
||||||
current_positions = afd_metadata.positions_list[stage_i]
|
current_positions = afd_metadata.positions_list[stage_i]
|
||||||
hidden_states, residual = layer(
|
hidden_states, residual = layer(
|
||||||
current_positions, hidden_states, residual
|
current_positions, hidden_states, residual
|
||||||
@ -439,9 +432,7 @@ class Step3TextModel(nn.Module):
|
|||||||
|
|
||||||
# Recv last layer FFN output.
|
# Recv last layer FFN output.
|
||||||
for stage_i in range(afd_metadata.num_of_stages):
|
for stage_i in range(afd_metadata.num_of_stages):
|
||||||
ubatch_hidden_states[stage_i], recv_metadata = (
|
ubatch_hidden_states[stage_i] = afd_connector.recv_ffn_output()
|
||||||
afd_connector.recv_ffn_output()
|
|
||||||
)
|
|
||||||
|
|
||||||
# Re-assemble the batch
|
# Re-assemble the batch
|
||||||
hidden_states = torch.cat(ubatch_hidden_states, dim=0)
|
hidden_states = torch.cat(ubatch_hidden_states, dim=0)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user