[Chore] bring back deleted code

This commit is contained in:
i-yuanyukun 2025-12-19 16:03:47 +08:00
parent 11d7d5bf59
commit 65ea10c8f4

View File

@ -15,6 +15,10 @@ from vllm.distributed.parallel_state import (
init_model_parallel_group,
)
from vllm.logger import init_logger
from vllm.forward_context import (
DPMetadata,
get_forward_context,
)
from .base import AFDConnectorBase
from .metadata import AFDConnectorMetadata
@ -59,6 +63,7 @@ class P2PAFDConnector(AFDConnectorBase):
self.recv_attn_output_counter: int = 0
self.recv_ffn_output_counter: int = 0
self.dp_metadata_list: dict[int, DPMetadata] = {}
def close(self) -> None:
"""Close the connector and release resources."""
@ -175,6 +180,19 @@ class P2PAFDConnector(AFDConnectorBase):
self._tensor_metadata_list = self._build_tensor_metadata_list(
tensor_metadata, self._current_afd_connector_metadata
)
logger.info(f"{self.config.parallel_config.data_parallel_size=}")
if self.config.parallel_config.data_parallel_size > 1:
logger.info("jcz recv_metadata num_of_stages:{}".format(self._current_afd_connector_metadata.num_of_stages))
for stage_idx in range(self._current_afd_connector_metadata.num_of_stages):
num_tokens_per_ubatch = self._tensor_metadata_list[stage_idx].size[0]
logger.info(f"{stage_idx=}, {num_tokens_per_ubatch=}")
self.dp_metadata_list[stage_idx] = DPMetadata.make(
self.config.parallel_config,
num_tokens_per_ubatch,
torch.tensor([num_tokens_per_ubatch] * self.config.parallel_config.data_parallel_size,
device="cpu", dtype=torch.int32),
)
logger.info("jcz recv_metadata self.dp_metadata_list:{}".format(self.dp_metadata_list))
def _send_hidden_states(
self,
@ -307,4 +325,5 @@ class P2PAFDConnector(AFDConnectorBase):
)
self._current_afd_connector_metadata.recv_handle_list = work_list
self._current_afd_connector_metadata.layer_idx = layer_idx
self._current_afd_connector_metadata.stage_idx = stage_idx
return hidden_states, self._current_afd_connector_metadata