[Bugfix][V1][P/D]Fix the issue where repeated requests for the same input produce abnormal outputs for P2pNcclConnector (#23403)

Signed-off-by: Abatom <abzhonghua@gmail.com>
Zhonghua Deng 2025-08-26 03:57:08 +08:00, committed by GitHub
parent 8a3cd90af5
commit 9188ae7cb5
3 changed files with 26 additions and 34 deletions


@@ -245,16 +245,33 @@ class P2pNcclConnector(KVConnectorBase_V1):
         assert self.p2p_nccl_engine is not None

+        def extract_kv_from_layer(
+            layer: torch.Tensor,
+            slot_mapping: torch.Tensor,
+        ) -> torch.Tensor:
+            """Extract the KV cache from the layer.
+
+            Assume the shape of the layer is (2, num_pages, page_size, xxx)
+            if MLA is not used, and (num_pages, page_size, xxx) otherwise.
+            """
+            if isinstance(attn_metadata, MLACommonMetadata):
+                num_pages, page_size = layer.shape[0], layer.shape[1]
+                return layer.reshape(num_pages * page_size, -1)[slot_mapping,
+                                                                ...]
+            num_pages, page_size = layer.shape[1], layer.shape[2]
+            return layer.reshape(2, num_pages * page_size, -1)[:, slot_mapping,
+                                                               ...]
+
         connector_metadata = self._get_connector_metadata()
         assert isinstance(connector_metadata, P2pNcclConnectorMetadata)
         for request in connector_metadata.requests:
             request_id = request.request_id
             ip, port = self.parse_request_id(request_id, True)
             remote_address = ip + ":" + str(port + self._rank)
-            self.p2p_nccl_engine.send_tensor(
-                request_id + "#" + layer_name, kv_layer, remote_address,
-                request.slot_mapping,
-                isinstance(attn_metadata, MLACommonMetadata))
+
+            kv_cache = extract_kv_from_layer(kv_layer, request.slot_mapping)
+            self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
+                                             kv_cache, remote_address)

     def wait_for_save(self):
         if self.is_producer:
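Why moving the extraction into the connector fixes the repeated-request corruption: the old code queued the whole paged `kv_layer` and sliced it later on the send thread, by which time a subsequent request could already have overwritten those pages. Indexing with a `slot_mapping` tensor is PyTorch advanced indexing, which materializes a copy, so extracting at save time snapshots the KV values before the cache can be reused. A minimal sketch (toy shapes, non-MLA layout) illustrating the copy semantics:

```python
import torch

# Paged KV layer for a non-MLA model: (2, num_pages, page_size, head_dim).
layer = torch.arange(2 * 4 * 2 * 3, dtype=torch.float32).reshape(2, 4, 2, 3)
slot_mapping = torch.tensor([1, 4, 5])  # flat slot = page * page_size + offset

num_pages, page_size = layer.shape[1], layer.shape[2]
kv_cache = layer.reshape(2, num_pages * page_size, -1)[:, slot_mapping, ...]

# Advanced (tensor) indexing returns a copy, so overwriting the paged
# cache afterwards, as the next request's forward pass does, leaves the
# extracted tensor intact.
layer.zero_()
assert kv_cache.abs().sum() > 0  # snapshot unaffected by the overwrite
```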


@@ -62,8 +62,6 @@ class SendQueueItem:
     tensor_id: str
     remote_address: str
     tensor: torch.Tensor
-    slot_mapping: torch.Tensor
-    is_mla: bool


 class P2pNcclEngine:
@@ -202,8 +200,6 @@ class P2pNcclEngine:
                     tensor_id: str,
                     tensor: torch.Tensor,
                     remote_address: typing.Optional[str] = None,
-                    slot_mapping: torch.Tensor = None,
-                    is_mla: bool = False,
                     ) -> bool:
         if remote_address is None:
             with self.recv_store_cv:
@@ -213,9 +209,7 @@ class P2pNcclEngine:
         item = SendQueueItem(tensor_id=tensor_id,
                              remote_address=remote_address,
-                             tensor=tensor,
-                             slot_mapping=slot_mapping,
-                             is_mla=is_mla)
+                             tensor=tensor)

         if self.send_type == "PUT":
             return self.send_sync(item)
@@ -433,9 +427,7 @@ class P2pNcclEngine:
         if item.remote_address not in self.socks:
             self.create_connect(item.remote_address)

         with self.send_stream:
-            tensor = self.extract_kv_from_layer(item.is_mla, item.tensor,
-                                                item.slot_mapping)
+            tensor = item.tensor
             sock = self.socks[item.remote_address]
             comm, rank = self.comms[item.remote_address]
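With the slice taken eagerly, the send thread's job reduces to transmitting `item.tensor` as-is. The hazard that deferred extraction created can be reproduced with plain Python objects; in this hypothetical sketch, a list and a `queue.Queue` stand in for the paged KV cache and the engine's send queue:

```python
import queue
import threading

send_queue: queue.Queue = queue.Queue()
sent = []

def send_worker():
    # Transmits whatever the buffer holds at send time, like the async
    # send thread above.
    while (item := send_queue.get()) is not None:
        sent.append(list(item))

t = threading.Thread(target=send_worker)
t.start()

kv_pages = [1, 2, 3]            # stands in for the paged KV cache
send_queue.put(kv_pages)         # old behavior: enqueue a reference
send_queue.put(list(kv_pages))   # fixed behavior: enqueue a snapshot
kv_pages[:] = [9, 9, 9]          # next request reuses the same pages
send_queue.put(None)
t.join()
print(sent)  # the reference may arrive as [9, 9, 9]; the copy is always [1, 2, 3]
```

Whether the reference arrives corrupted depends on thread timing, which matches the intermittent "abnormal outputs" symptom in the issue.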
@@ -548,21 +540,3 @@ class P2pNcclEngine:
             self._send_thread.join()
         if self._ping_thread is not None:
             self._ping_thread.join()
-
-    @staticmethod
-    def extract_kv_from_layer(
-        is_mla: bool,
-        layer: torch.Tensor,
-        slot_mapping: torch.Tensor,
-    ) -> torch.Tensor:
-        """Extract the KV cache from the layer.
-
-        Assume the shape of the layer is (2, num_pages, page_size, xxx)
-        if MLA is not used, and (num_pages, page_size, xxx) otherwise.
-        """
-        if is_mla:
-            num_pages, page_size = layer.shape[0], layer.shape[1]
-            return layer.reshape(num_pages * page_size, -1)[slot_mapping, ...]
-        num_pages, page_size = layer.shape[1], layer.shape[2]
-        return layer.reshape(2, num_pages * page_size, -1)[:, slot_mapping,
-                                                           ...]
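For reference, the indexing this removed helper performed (now done connector-side, next to `save_kv_layer`) is plain page/offset arithmetic: a flat slot id is page_index * page_size + offset_in_page, so flattening the page dimensions lets `slot_mapping` gather tokens directly. A small self-check with toy shapes (MLA-style `(num_pages, page_size, xxx)` layout; the concrete sizes are illustrative):

```python
import torch

num_pages, page_size, head_dim = 3, 4, 2
layer = torch.arange(num_pages * page_size * head_dim,
                     dtype=torch.float32).reshape(num_pages, page_size,
                                                  head_dim)

# Flat slot id = page_index * page_size + offset_in_page.
slot_mapping = torch.tensor([0, 5, 11])  # (page, offset) = (0,0), (1,1), (2,3)

flat = layer.reshape(num_pages * page_size, -1)  # (num_slots, head_dim)
extracted = flat[slot_mapping, ...]

assert torch.equal(extracted[1], layer[1, 1])
assert torch.equal(extracted[2], layer[2, 3])
```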


@@ -99,8 +99,9 @@ class TensorMemoryPool:
                                          addr=self.base_address)
         self.free_lists[self.max_block_size][
             initial_block.addr] = initial_block
-        logger.debug("TensorMemoryPool, base_address:", self.base_address,
-                     self.base_address % self.max_block_size)
+        logger.debug("TensorMemoryPool, base_address:%d, max_block_size:%d",
+                     self.base_address, self.max_block_size)

     def allocate(self, size: int) -> int:
         """Allocates a memory block of at least the requested size.