mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-04 20:22:15 +08:00
[Bugfix][V1][P/D]Fix the issue where repeated requests for the same input produce abnormal outputs for P2pNcclConnector (#23403)
Signed-off-by: Abatom <abzhonghua@gmail.com>
This commit is contained in:
parent
8a3cd90af5
commit
9188ae7cb5
@ -245,16 +245,33 @@ class P2pNcclConnector(KVConnectorBase_V1):
|
|||||||
|
|
||||||
assert self.p2p_nccl_engine is not None
|
assert self.p2p_nccl_engine is not None
|
||||||
|
|
||||||
|
def extract_kv_from_layer(
|
||||||
|
layer: torch.Tensor,
|
||||||
|
slot_mapping: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Extract the KV cache from the layer.
|
||||||
|
|
||||||
|
Assume the shape of the layer is (2, num_pages, page_size, xxx)
|
||||||
|
if MLA is not used, and (num_pages, page_size, xxx) otherwise.
|
||||||
|
"""
|
||||||
|
if isinstance(attn_metadata, MLACommonMetadata):
|
||||||
|
num_pages, page_size = layer.shape[0], layer.shape[1]
|
||||||
|
return layer.reshape(num_pages * page_size, -1)[slot_mapping,
|
||||||
|
...]
|
||||||
|
num_pages, page_size = layer.shape[1], layer.shape[2]
|
||||||
|
return layer.reshape(2, num_pages * page_size, -1)[:, slot_mapping,
|
||||||
|
...]
|
||||||
|
|
||||||
connector_metadata = self._get_connector_metadata()
|
connector_metadata = self._get_connector_metadata()
|
||||||
assert isinstance(connector_metadata, P2pNcclConnectorMetadata)
|
assert isinstance(connector_metadata, P2pNcclConnectorMetadata)
|
||||||
for request in connector_metadata.requests:
|
for request in connector_metadata.requests:
|
||||||
request_id = request.request_id
|
request_id = request.request_id
|
||||||
ip, port = self.parse_request_id(request_id, True)
|
ip, port = self.parse_request_id(request_id, True)
|
||||||
remote_address = ip + ":" + str(port + self._rank)
|
remote_address = ip + ":" + str(port + self._rank)
|
||||||
self.p2p_nccl_engine.send_tensor(
|
|
||||||
request_id + "#" + layer_name, kv_layer, remote_address,
|
kv_cache = extract_kv_from_layer(kv_layer, request.slot_mapping)
|
||||||
request.slot_mapping,
|
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
|
||||||
isinstance(attn_metadata, MLACommonMetadata))
|
kv_cache, remote_address)
|
||||||
|
|
||||||
def wait_for_save(self):
|
def wait_for_save(self):
|
||||||
if self.is_producer:
|
if self.is_producer:
|
||||||
|
|||||||
@ -62,8 +62,6 @@ class SendQueueItem:
|
|||||||
tensor_id: str
|
tensor_id: str
|
||||||
remote_address: str
|
remote_address: str
|
||||||
tensor: torch.Tensor
|
tensor: torch.Tensor
|
||||||
slot_mapping: torch.Tensor
|
|
||||||
is_mla: bool
|
|
||||||
|
|
||||||
|
|
||||||
class P2pNcclEngine:
|
class P2pNcclEngine:
|
||||||
@ -202,8 +200,6 @@ class P2pNcclEngine:
|
|||||||
tensor_id: str,
|
tensor_id: str,
|
||||||
tensor: torch.Tensor,
|
tensor: torch.Tensor,
|
||||||
remote_address: typing.Optional[str] = None,
|
remote_address: typing.Optional[str] = None,
|
||||||
slot_mapping: torch.Tensor = None,
|
|
||||||
is_mla: bool = False,
|
|
||||||
) -> bool:
|
) -> bool:
|
||||||
if remote_address is None:
|
if remote_address is None:
|
||||||
with self.recv_store_cv:
|
with self.recv_store_cv:
|
||||||
@ -213,9 +209,7 @@ class P2pNcclEngine:
|
|||||||
|
|
||||||
item = SendQueueItem(tensor_id=tensor_id,
|
item = SendQueueItem(tensor_id=tensor_id,
|
||||||
remote_address=remote_address,
|
remote_address=remote_address,
|
||||||
tensor=tensor,
|
tensor=tensor)
|
||||||
slot_mapping=slot_mapping,
|
|
||||||
is_mla=is_mla)
|
|
||||||
|
|
||||||
if self.send_type == "PUT":
|
if self.send_type == "PUT":
|
||||||
return self.send_sync(item)
|
return self.send_sync(item)
|
||||||
@ -433,9 +427,7 @@ class P2pNcclEngine:
|
|||||||
if item.remote_address not in self.socks:
|
if item.remote_address not in self.socks:
|
||||||
self.create_connect(item.remote_address)
|
self.create_connect(item.remote_address)
|
||||||
|
|
||||||
with self.send_stream:
|
tensor = item.tensor
|
||||||
tensor = self.extract_kv_from_layer(item.is_mla, item.tensor,
|
|
||||||
item.slot_mapping)
|
|
||||||
|
|
||||||
sock = self.socks[item.remote_address]
|
sock = self.socks[item.remote_address]
|
||||||
comm, rank = self.comms[item.remote_address]
|
comm, rank = self.comms[item.remote_address]
|
||||||
@ -548,21 +540,3 @@ class P2pNcclEngine:
|
|||||||
self._send_thread.join()
|
self._send_thread.join()
|
||||||
if self._ping_thread is not None:
|
if self._ping_thread is not None:
|
||||||
self._ping_thread.join()
|
self._ping_thread.join()
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def extract_kv_from_layer(
|
|
||||||
is_mla: bool,
|
|
||||||
layer: torch.Tensor,
|
|
||||||
slot_mapping: torch.Tensor,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""Extract the KV cache from the layer.
|
|
||||||
Assume the shape of the layer is (2, num_pages, page_size, xxx)
|
|
||||||
if MLA is not used, and (num_pages, page_size, xxx) otherwise.
|
|
||||||
"""
|
|
||||||
if is_mla:
|
|
||||||
num_pages, page_size = layer.shape[0], layer.shape[1]
|
|
||||||
return layer.reshape(num_pages * page_size, -1)[slot_mapping, ...]
|
|
||||||
|
|
||||||
num_pages, page_size = layer.shape[1], layer.shape[2]
|
|
||||||
return layer.reshape(2, num_pages * page_size, -1)[:, slot_mapping,
|
|
||||||
...]
|
|
||||||
|
|||||||
@ -99,8 +99,9 @@ class TensorMemoryPool:
|
|||||||
addr=self.base_address)
|
addr=self.base_address)
|
||||||
self.free_lists[self.max_block_size][
|
self.free_lists[self.max_block_size][
|
||||||
initial_block.addr] = initial_block
|
initial_block.addr] = initial_block
|
||||||
logger.debug("TensorMemoryPool, base_address:", self.base_address,
|
|
||||||
self.base_address % self.max_block_size)
|
logger.debug("TensorMemoryPool, base_address:%d, max_block_size:%d",
|
||||||
|
self.base_address, self.max_block_size)
|
||||||
|
|
||||||
def allocate(self, size: int) -> int:
|
def allocate(self, size: int) -> int:
|
||||||
"""Allocates a memory block of at least the requested size.
|
"""Allocates a memory block of at least the requested size.
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user