mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-05 11:29:09 +08:00
[Chore]: step3 forward_with_afd
This commit is contained in:
parent
6d305dda38
commit
2a98ab3c8e
@ -398,43 +398,72 @@ class Step3TextModel(nn.Module):
|
|||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
afd_metadata: AFDMetadata,
|
afd_metadata: AFDMetadata,
|
||||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
forward_conext = get_forward_context()
|
||||||
recv_handle = None
|
recv_handle = None
|
||||||
logger.info(f"{__file__}: forward with afd called, may blocked here")
|
|
||||||
for layer in islice(self.layers, self.start_layer, self.end_layer):
|
|
||||||
afd_connector = afd_metadata.afd_connector
|
|
||||||
afd_metadata.afd_stage_idx = dbo_current_ubatch_id()
|
|
||||||
|
|
||||||
if layer.layer_idx > 0:
|
ubatch_hidden_states = []
|
||||||
hidden_states, recv_metadata = afd_connector.recv_ffn_output()
|
ubatch_residual = []
|
||||||
if recv_metadata.recv_handle_list is not None:
|
|
||||||
recv_handle = recv_metadata.recv_handle_list
|
|
||||||
|
|
||||||
if recv_handle is not None:
|
start_idx = 0
|
||||||
for work in recv_handle:
|
for pos in afd_metadata.positions_list:
|
||||||
work.wait()
|
num_tokens = pos.shape[1] if pos.ndim == 2 else pos.shape[0]
|
||||||
logger.info(f"Step3TextModel {layer.layer_idx=}: {hidden_states.shape=}, {positions.shape=}")
|
end_idx = start_idx + num_tokens
|
||||||
current_hidden, residual = layer(positions, hidden_states, residual)
|
ubatch_hidden_states.append(hidden_states[start_idx:end_idx])
|
||||||
logger.info(f"create attn metadata: {current_hidden.shape=}")
|
ubatch_residual.append(
|
||||||
metadata = AFDConnectorMetadata.create_attention_metadata(
|
residual[start_idx:end_idx] if residual is not None else None
|
||||||
layer_idx=layer.layer_idx,
|
|
||||||
stage_idx=afd_metadata.afd_stage_idx,
|
|
||||||
seq_len=current_hidden.shape[0],
|
|
||||||
dtype=current_hidden.dtype,
|
|
||||||
device=current_hidden.device,
|
|
||||||
num_of_stages=afd_metadata.num_of_stages,
|
|
||||||
afd_tokens_lens=afd_metadata.afd_tokens_lens,
|
|
||||||
)
|
)
|
||||||
afd_connector.send_attn_output(current_hidden, metadata)
|
start_idx = end_idx
|
||||||
|
|
||||||
if dbo_enabled():
|
for layer in islice(self.layers, self.start_layer, self.end_layer):
|
||||||
dbo_yield()
|
for stage_i in range(forward_conext.afd_metadata.num_of_stages):
|
||||||
|
afd_connector = afd_metadata.afd_connector
|
||||||
|
forward_conext.attn_metadata = afd_metadata.attn_metadata_list[stage_i]
|
||||||
|
forward_conext.dp_metadata = afd_metadata.dp_metadata_list[stage_i]
|
||||||
|
|
||||||
hidden_states, recv_metadata = afd_connector.recv_ffn_output()
|
residual = ubatch_residual[stage_i]
|
||||||
if recv_metadata.recv_handle_list is not None:
|
|
||||||
recv_handle = recv_metadata.recv_handle_list
|
if layer.layer_idx > 0:
|
||||||
if recv_handle is not None:
|
hidden_states, recv_metadata = afd_connector.recv_ffn_output()
|
||||||
for work in recv_handle:
|
if recv_metadata.recv_handle_list is not None:
|
||||||
work.wait()
|
recv_handle = recv_metadata.recv_handle_list
|
||||||
|
else:
|
||||||
|
hidden_states = ubatch_hidden_states[stage_i]
|
||||||
|
|
||||||
|
if recv_handle is not None:
|
||||||
|
for work in recv_handle:
|
||||||
|
work.wait()
|
||||||
|
|
||||||
|
current_positions = afd_metadata.positions_list[stage_i]
|
||||||
|
hidden_states, residual = layer(
|
||||||
|
current_positions, hidden_states, residual
|
||||||
|
)
|
||||||
|
|
||||||
|
ubatch_hidden_states[stage_i] = hidden_states
|
||||||
|
ubatch_residual[stage_i] = residual
|
||||||
|
logger.info(f"create attn metadata:, {afd_metadata.afd_tokens_lens=}")
|
||||||
|
metadata = AFDConnectorMetadata.create_attention_metadata(
|
||||||
|
layer_idx=layer.layer_idx,
|
||||||
|
stage_idx=stage_i,
|
||||||
|
seq_len=hidden_states.shape[0],
|
||||||
|
dtype=hidden_states.dtype,
|
||||||
|
device=hidden_states.device,
|
||||||
|
num_of_stages=afd_metadata.num_of_stages,
|
||||||
|
afd_tokens_lens=afd_metadata.afd_tokens_lens,
|
||||||
|
)
|
||||||
|
afd_connector.send_attn_output(hidden_states, metadata)
|
||||||
|
|
||||||
|
# Recv last layer FFN output.
|
||||||
|
for stage_i in range(afd_metadata.num_of_stages):
|
||||||
|
ubatch_hidden_states[stage_i], recv_metadata = (
|
||||||
|
afd_connector.recv_ffn_output()
|
||||||
|
)
|
||||||
|
|
||||||
|
# Re-assemble the batch
|
||||||
|
hidden_states = torch.cat(ubatch_hidden_states, dim=0)
|
||||||
|
if any(r is not None for r in ubatch_residual):
|
||||||
|
residual = torch.cat(ubatch_residual, dim=0)
|
||||||
|
else:
|
||||||
|
residual = None
|
||||||
|
|
||||||
return hidden_states, residual
|
return hidden_states, residual
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user