[Bugfix] Fix several issues with p2p xPyD in GET type (#23993)

Signed-off-by: Csrayz <jover@cmbchina.com>
Signed-off-by: ivyilike <pww123@cmbchina.com>
Co-authored-by: ivyilike <pww123@cmbchina.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
Csrayz 2025-09-22 22:53:13 +08:00 committed by yewentao256
parent cc494282a9
commit 4057e2b162
2 changed files with 21 additions and 10 deletions

View File

@ -178,6 +178,9 @@ class P2pNcclConnector(KVConnectorBase_V1):
# Load the KV for each request each layer
for request in metadata.requests:
request_id = request.request_id
ip, port = self.parse_request_id(request_id, False)
remote_address = ip + ":" + str(port + self._rank)
for layer_name in forward_context.no_compile_layers:
layer = forward_context.no_compile_layers[layer_name]
@ -191,7 +194,7 @@ class P2pNcclConnector(KVConnectorBase_V1):
layer = kv_cache[forward_context.virtual_engine]
kv_cache = self.p2p_nccl_engine.recv_tensor(
request.request_id + "#" + layer_name)
request.request_id + "#" + layer_name, remote_address)
if kv_cache is None:
logger.warning("🚧kv_cache is None, %s", request.request_id)

View File

@ -134,7 +134,6 @@ class P2pNcclEngine:
# PUT or PUT_ASYNC
# tensor_id: torch.Tensor
self.send_queue: deque[SendQueueItem] = deque()
self.send_request_id_to_tensor_ids: dict[str, set[str]] = {}
if self.send_type == "PUT_ASYNC":
self._send_thread = threading.Thread(target=self.send_async,
daemon=True)
@ -143,6 +142,7 @@ class P2pNcclEngine:
# tensor_id: torch.Tensor/(addr, dtype, shape)
self.recv_store: dict[str, Any] = {}
self.recv_request_id_to_tensor_ids: dict[str, set[str]] = {}
self.send_request_id_to_tensor_ids: dict[str, set[str]] = {}
self.socks: dict[str, Any] = {} # remote_address: client socket
self.comms: dict[str, Any] = {} # remote_address: (ncclComm_t, rank)
@ -223,18 +223,26 @@ class P2pNcclEngine:
# GET
with self.send_store_cv:
tensor_size = tensor.element_size() * tensor.numel()
if tensor_size > self.buffer_size_threshold:
logger.warning(
"❗[GET]tensor_id:%s, tensor_size:%d, is greater than"
"buffer size threshold :%d, skip send to %s, rank:%d",
tensor_id, tensor_size, self.buffer_size_threshold,
remote_address, self.rank)
return False
while (self.buffer_size + tensor_size
> self.buffer_size_threshold):
oldest_tenser_id = next(iter(self.send_store))
oldest_tenser = self.send_store.pop(oldest_tenser_id)
oldest_tenser_size = oldest_tenser.element_size(
) * oldest_tenser.numel()
self.buffer_size -= oldest_tenser_size
logger.info(
assert len(self.send_store) > 0
oldest_tensor_id = next(iter(self.send_store))
oldest_tensor = self.send_store.pop(oldest_tensor_id)
oldest_tensor_size = oldest_tensor.element_size(
) * oldest_tensor.numel()
self.buffer_size -= oldest_tensor_size
logger.debug(
"⛔[GET]Send to %s, tensor_id:%s, tensor_size:%d,"
" buffer_size:%d, oldest_tenser_size:%d, rank:%d",
" buffer_size:%d, oldest_tensor_size:%d, rank:%d",
remote_address, tensor_id, tensor_size, self.buffer_size,
oldest_tenser_size, self.rank)
oldest_tensor_size, self.rank)
self.send_store[tensor_id] = tensor
self.buffer_size += tensor_size