diff --git a/vllm/model_executor/models/step3_text.py b/vllm/model_executor/models/step3_text.py
index 6355573a68774..75a38805ef5aa 100644
--- a/vllm/model_executor/models/step3_text.py
+++ b/vllm/model_executor/models/step3_text.py
@@ -398,43 +398,72 @@ class Step3TextModel(nn.Module):
         positions: torch.Tensor,
         afd_metadata: AFDMetadata,
     ) -> tuple[torch.Tensor, torch.Tensor]:
+        forward_context = get_forward_context()
+        afd_connector = afd_metadata.afd_connector
         recv_handle = None
-        logger.info(f"{__file__}: forward with afd called, may blocked here")
-        for layer in islice(self.layers, self.start_layer, self.end_layer):
-            afd_connector = afd_metadata.afd_connector
-            afd_metadata.afd_stage_idx = dbo_current_ubatch_id()
-            if layer.layer_idx > 0:
-                hidden_states, recv_metadata = afd_connector.recv_ffn_output()
-                if recv_metadata.recv_handle_list is not None:
-                    recv_handle = recv_metadata.recv_handle_list
+        ubatch_hidden_states = []
+        ubatch_residual = []
 
-            if recv_handle is not None:
-                for work in recv_handle:
-                    work.wait()
-            logger.info(f"Step3TextModel {layer.layer_idx=}: {hidden_states.shape=}, {positions.shape=}")
-            current_hidden, residual = layer(positions, hidden_states, residual)
-            logger.info(f"create attn metadata: {current_hidden.shape=}")
-            metadata = AFDConnectorMetadata.create_attention_metadata(
-                layer_idx=layer.layer_idx,
-                stage_idx=afd_metadata.afd_stage_idx,
-                seq_len=current_hidden.shape[0],
-                dtype=current_hidden.dtype,
-                device=current_hidden.device,
-                num_of_stages=afd_metadata.num_of_stages,
-                afd_tokens_lens=afd_metadata.afd_tokens_lens,
+        # Split the flat batch into one micro-batch per AFD stage.
+        start_idx = 0
+        for pos in afd_metadata.positions_list:
+            num_tokens = pos.shape[1] if pos.ndim == 2 else pos.shape[0]
+            end_idx = start_idx + num_tokens
+            ubatch_hidden_states.append(hidden_states[start_idx:end_idx])
+            ubatch_residual.append(
+                residual[start_idx:end_idx] if residual is not None else None
             )
-            afd_connector.send_attn_output(current_hidden, metadata)
+            start_idx = end_idx
 
-            if dbo_enabled():
-                dbo_yield()
+        for layer in islice(self.layers, self.start_layer, self.end_layer):
+            for stage_i in range(afd_metadata.num_of_stages):
+                forward_context.attn_metadata = afd_metadata.attn_metadata_list[stage_i]
+                forward_context.dp_metadata = afd_metadata.dp_metadata_list[stage_i]
 
-        hidden_states, recv_metadata = afd_connector.recv_ffn_output()
-        if recv_metadata.recv_handle_list is not None:
-            recv_handle = recv_metadata.recv_handle_list
-        if recv_handle is not None:
-            for work in recv_handle:
-                work.wait()
+                residual = ubatch_residual[stage_i]
+
+                # After the first layer, this stage's input comes back from
+                # the FFN side; otherwise use the local micro-batch.
+                if layer.layer_idx > 0:
+                    hidden_states, recv_metadata = afd_connector.recv_ffn_output()
+                    if recv_metadata.recv_handle_list is not None:
+                        recv_handle = recv_metadata.recv_handle_list
+                else:
+                    hidden_states = ubatch_hidden_states[stage_i]
+
+                if recv_handle is not None:
+                    for work in recv_handle:
+                        work.wait()
+                    recv_handle = None
+
+                current_positions = afd_metadata.positions_list[stage_i]
+                hidden_states, residual = layer(
+                    current_positions, hidden_states, residual
+                )
+
+                ubatch_hidden_states[stage_i] = hidden_states
+                ubatch_residual[stage_i] = residual
+                logger.debug(f"create attn metadata: {afd_metadata.afd_tokens_lens=}")
+                metadata = AFDConnectorMetadata.create_attention_metadata(
+                    layer_idx=layer.layer_idx,
+                    stage_idx=stage_i,
+                    seq_len=hidden_states.shape[0],
+                    dtype=hidden_states.dtype,
+                    device=hidden_states.device,
+                    num_of_stages=afd_metadata.num_of_stages,
+                    afd_tokens_lens=afd_metadata.afd_tokens_lens,
+                )
+                afd_connector.send_attn_output(hidden_states, metadata)
+
+        # Receive the last layer's FFN output for each stage, waiting on any
+        # async recv handles before the results are used.
+        for stage_i in range(afd_metadata.num_of_stages):
+            ubatch_hidden_states[stage_i], recv_metadata = (
+                afd_connector.recv_ffn_output()
+            )
+            if recv_metadata.recv_handle_list is not None:
+                for work in recv_metadata.recv_handle_list:
+                    work.wait()
+
+        # Re-assemble the full batch from the per-stage micro-batches.
+        hidden_states = torch.cat(ubatch_hidden_states, dim=0)
+        if any(r is not None for r in ubatch_residual):
+            residual = torch.cat(ubatch_residual, dim=0)
+        else:
+            residual = None
         return hidden_states, residual
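
For review context, below is a minimal, self-contained sketch of the per-layer, per-stage
send/recv loop the new forward implements. FakeConnector, staged_forward, and the identity
"layers" are illustrative stand-ins rather than vLLM APIs; the real connector is asynchronous,
carries AFDConnectorMetadata, and the real layers also take positions and residual, all of
which the sketch omits.

import torch

class FakeConnector:
    """Toy stand-in: echoes the attention output back as the 'FFN output'."""
    def __init__(self):
        self._queue = []

    def send_attn_output(self, hidden):
        self._queue.append(hidden)

    def recv_ffn_output(self):
        return self._queue.pop(0)

def staged_forward(hidden_states, positions_list, layers, connector):
    # Split the flat batch into one micro-batch per stage (1-D positions
    # assumed here; the patch also handles 2-D positions).
    ubatch = []
    start = 0
    for pos in positions_list:
        n = pos.shape[0]
        ubatch.append(hidden_states[start:start + n])
        start += n

    num_stages = len(positions_list)
    for layer_idx, layer in enumerate(layers):
        for stage_i in range(num_stages):
            # After the first layer, this stage's input comes back from
            # the FFN side; otherwise use the local micro-batch.
            h = connector.recv_ffn_output() if layer_idx > 0 else ubatch[stage_i]
            h = layer(h)                   # attention-side compute only
            ubatch[stage_i] = h
            connector.send_attn_output(h)  # hand off to the FFN side

    # Drain the final layer's FFN outputs, then re-assemble the batch.
    for stage_i in range(num_stages):
        ubatch[stage_i] = connector.recv_ffn_output()
    return torch.cat(ubatch, dim=0)

# Toy usage: two stages of 3 and 5 tokens, two identity "layers".
hs = torch.randn(8, 16)
pos = [torch.arange(3), torch.arange(5)]
out = staged_forward(hs, pos, [torch.nn.Identity()] * 2, FakeConnector())
assert out.shape == hs.shape

The point of interleaving stages inside each layer is that while one stage's micro-batch waits
on the FFN side, the other stage's attention can proceed, which the echo-queue above models
only in ordering, not in actual overlap.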