mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-22 19:25:01 +08:00
[NIXL][BUG FIX] Fix a bug for PD with host_buffer after merging 29665 (#30420)
Signed-off-by: Chendi Xue <chendi.xue@intel.com> Signed-off-by: Mark McLoughlin <markmc@redhat.com> Co-authored-by: Mark McLoughlin <markmc@redhat.com>
This commit is contained in:
parent
9e33a1a75b
commit
ae2e503dda
@ -461,7 +461,7 @@ class TestNixlHandshake:
|
||||
metadata = NixlConnectorMetadata()
|
||||
if num_xfers > 0:
|
||||
num_xfers -= 1
|
||||
metadata.add_new_req(
|
||||
metadata.add_new_req_to_recv(
|
||||
request_id=request_id,
|
||||
local_block_ids=[num_xfers + 1, num_xfers + 2, num_xfers + 3],
|
||||
kv_transfer_params={
|
||||
@ -532,7 +532,7 @@ class TestNixlHandshake:
|
||||
vllm_config, connector.engine_id
|
||||
)
|
||||
metadata = NixlConnectorMetadata()
|
||||
metadata.add_new_req(
|
||||
metadata.add_new_req_to_recv(
|
||||
request_id="id",
|
||||
local_block_ids=[1, 2, 3],
|
||||
kv_transfer_params={
|
||||
@ -588,7 +588,7 @@ class TestNixlHandshake:
|
||||
metadata = NixlConnectorMetadata()
|
||||
total_reqs = 5
|
||||
for i in range(total_reqs):
|
||||
metadata.add_new_req(
|
||||
metadata.add_new_req_to_recv(
|
||||
request_id=f"id_{i}",
|
||||
local_block_ids=[1, 2, 3],
|
||||
kv_transfer_params={
|
||||
@ -752,7 +752,7 @@ def test_kv_connector_stats(dist_init):
|
||||
# Create transfer metadata
|
||||
request_id = "test_req_for_stats"
|
||||
metadata = NixlConnectorMetadata()
|
||||
metadata.add_new_req(
|
||||
metadata.add_new_req_to_recv(
|
||||
request_id=request_id,
|
||||
local_block_ids=[1, 2, 3],
|
||||
kv_transfer_params={
|
||||
@ -1515,7 +1515,7 @@ def test_handshake_failure_returns_finished(dist_init):
|
||||
|
||||
request_id = "test_handshake_fail"
|
||||
metadata = NixlConnectorMetadata()
|
||||
metadata.add_new_req(
|
||||
metadata.add_new_req_to_recv(
|
||||
request_id=request_id,
|
||||
local_block_ids=[1, 2, 3],
|
||||
kv_transfer_params={
|
||||
@ -1565,7 +1565,7 @@ def test_transfer_setup_failure_returns_finished(dist_init):
|
||||
|
||||
request_id = "test_transfer_fail"
|
||||
metadata = NixlConnectorMetadata()
|
||||
metadata.add_new_req(
|
||||
metadata.add_new_req_to_recv(
|
||||
request_id=request_id,
|
||||
local_block_ids=[7, 8, 9],
|
||||
kv_transfer_params={
|
||||
|
||||
@ -202,17 +202,22 @@ def compute_nixl_compatibility_hash(
|
||||
return compat_hash
|
||||
|
||||
|
||||
@dataclass
|
||||
class RemoteMeta:
|
||||
block_ids: list[int]
|
||||
host: str
|
||||
port: int
|
||||
engine_id: str
|
||||
request_id: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class ReqMeta:
|
||||
local_block_ids: list[int]
|
||||
# To be used when logical block size does not match the kernel block size
|
||||
local_physical_block_ids: list[int]
|
||||
remote_block_ids: list[int]
|
||||
remote_host: str
|
||||
remote_port: int
|
||||
remote_engine_id: str
|
||||
remote_request_id: str
|
||||
tp_size: int
|
||||
remote: RemoteMeta | None = None
|
||||
|
||||
|
||||
class NixlConnectorMetadata(KVConnectorMetadata):
|
||||
@ -223,31 +228,43 @@ class NixlConnectorMetadata(KVConnectorMetadata):
|
||||
self.reqs_in_batch: set[ReqId] = set()
|
||||
self.reqs_not_processed: set[ReqId] = set()
|
||||
|
||||
def add_new_req(
|
||||
def _add_new_req(
|
||||
self,
|
||||
local_block_ids: list[int],
|
||||
kv_transfer_params: dict[str, Any],
|
||||
) -> ReqMeta:
|
||||
return ReqMeta(
|
||||
local_block_ids=local_block_ids,
|
||||
local_physical_block_ids=local_block_ids,
|
||||
# P workers don't need to receive tp_size from proxy here.
|
||||
tp_size=kv_transfer_params.get("tp_size", 1),
|
||||
)
|
||||
|
||||
def add_new_req_to_save(
|
||||
self,
|
||||
request_id: ReqId,
|
||||
local_block_ids: list[int],
|
||||
kv_transfer_params: dict[str, Any],
|
||||
load_remote_cache: bool = True,
|
||||
save_to_host: bool = False,
|
||||
):
|
||||
# save and load are mutually exclusive
|
||||
assert load_remote_cache ^ save_to_host
|
||||
_req = ReqMeta(
|
||||
local_block_ids=local_block_ids,
|
||||
local_physical_block_ids=local_block_ids,
|
||||
remote_block_ids=kv_transfer_params["remote_block_ids"],
|
||||
remote_engine_id=kv_transfer_params["remote_engine_id"],
|
||||
remote_request_id=kv_transfer_params["remote_request_id"],
|
||||
remote_host=kv_transfer_params["remote_host"],
|
||||
remote_port=kv_transfer_params["remote_port"],
|
||||
# P workers don't need to receive tp_size from proxy here.
|
||||
tp_size=kv_transfer_params.get("tp_size", 1),
|
||||
self.reqs_to_save[request_id] = self._add_new_req(
|
||||
local_block_ids, kv_transfer_params
|
||||
)
|
||||
if save_to_host:
|
||||
self.reqs_to_save[request_id] = _req
|
||||
if load_remote_cache:
|
||||
self.reqs_to_recv[request_id] = _req
|
||||
|
||||
def add_new_req_to_recv(
|
||||
self,
|
||||
request_id: ReqId,
|
||||
local_block_ids: list[int],
|
||||
kv_transfer_params: dict[str, Any],
|
||||
):
|
||||
req = self._add_new_req(local_block_ids, kv_transfer_params)
|
||||
req.remote = RemoteMeta(
|
||||
block_ids=kv_transfer_params["remote_block_ids"],
|
||||
engine_id=kv_transfer_params["remote_engine_id"],
|
||||
request_id=kv_transfer_params["remote_request_id"],
|
||||
host=kv_transfer_params["remote_host"],
|
||||
port=kv_transfer_params["remote_port"],
|
||||
)
|
||||
self.reqs_to_recv[request_id] = req
|
||||
|
||||
|
||||
class NixlConnector(KVConnectorBase_V1):
|
||||
@ -666,22 +683,18 @@ class NixlConnectorScheduler:
|
||||
# Loop through scheduled reqs and convert to ReqMeta.
|
||||
for req_id, (req, block_ids) in self._reqs_need_recv.items():
|
||||
assert req.kv_transfer_params is not None
|
||||
meta.add_new_req(
|
||||
meta.add_new_req_to_recv(
|
||||
request_id=req_id,
|
||||
local_block_ids=block_ids,
|
||||
kv_transfer_params=req.kv_transfer_params,
|
||||
load_remote_cache=True,
|
||||
save_to_host=False,
|
||||
)
|
||||
|
||||
for req_id, (req, block_ids) in self._reqs_need_save.items():
|
||||
assert req.kv_transfer_params is not None
|
||||
meta.add_new_req(
|
||||
meta.add_new_req_to_save(
|
||||
request_id=req_id,
|
||||
local_block_ids=block_ids,
|
||||
kv_transfer_params=req.kv_transfer_params,
|
||||
load_remote_cache=False,
|
||||
save_to_host=True,
|
||||
)
|
||||
|
||||
meta.reqs_to_send = self._reqs_need_send
|
||||
@ -1124,10 +1137,11 @@ class NixlConnectorWorker:
|
||||
# Do NIXL handshake in background and add to _ready_requests when done.
|
||||
fut = self._handshake_futures.get(remote_engine_id)
|
||||
if fut is None:
|
||||
assert meta.remote is not None
|
||||
fut = self._handshake_initiation_executor.submit(
|
||||
self._nixl_handshake,
|
||||
meta.remote_host,
|
||||
meta.remote_port,
|
||||
meta.remote.host,
|
||||
meta.remote.port,
|
||||
meta.tp_size,
|
||||
remote_engine_id,
|
||||
)
|
||||
@ -1774,6 +1788,7 @@ class NixlConnectorWorker:
|
||||
# clean up metadata for completed requests
|
||||
meta = self._recving_metadata.pop(req_id, None)
|
||||
assert meta is not None, f"{req_id} not found in recving_metadata list"
|
||||
assert meta.remote is not None
|
||||
if self.use_host_buffer:
|
||||
self.sync_recved_kv_to_device(req_id, meta)
|
||||
if self.enable_permute_local_kv:
|
||||
@ -1781,7 +1796,7 @@ class NixlConnectorWorker:
|
||||
|
||||
# post processing for heteroblocksize
|
||||
block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id(
|
||||
meta.remote_engine_id
|
||||
meta.remote.engine_id
|
||||
)
|
||||
if (
|
||||
not self.use_mla
|
||||
@ -1916,17 +1931,18 @@ class NixlConnectorWorker:
|
||||
meta.local_physical_block_ids = self._logical_to_kernel_block_ids(
|
||||
meta.local_block_ids
|
||||
)
|
||||
meta.remote_block_ids = self._logical_to_kernel_block_ids(
|
||||
meta.remote_block_ids
|
||||
assert meta.remote is not None
|
||||
meta.remote.block_ids = self._logical_to_kernel_block_ids(
|
||||
meta.remote.block_ids
|
||||
)
|
||||
remote_engine_id = meta.remote_engine_id
|
||||
remote_engine_id = meta.remote.engine_id
|
||||
logger.debug(
|
||||
"start_load_kv for request %s from remote engine %s. "
|
||||
"Num local_block_ids: %s. Num remote_block_ids: %s. ",
|
||||
req_id,
|
||||
remote_engine_id,
|
||||
len(meta.local_physical_block_ids),
|
||||
len(meta.remote_block_ids),
|
||||
len(meta.remote.block_ids),
|
||||
)
|
||||
# always store metadata for failure recovery
|
||||
self._recving_metadata[req_id] = meta
|
||||
@ -1965,17 +1981,18 @@ class NixlConnectorWorker:
|
||||
self._reqs_to_send[req_id] = expiration_time
|
||||
|
||||
def _read_blocks_for_req(self, req_id: str, meta: ReqMeta):
|
||||
assert meta.remote is not None
|
||||
logger.debug(
|
||||
"Remote agent %s available, calling _read_blocks for req %s",
|
||||
meta.remote_engine_id,
|
||||
meta.remote.engine_id,
|
||||
req_id,
|
||||
)
|
||||
self._read_blocks(
|
||||
request_id=req_id,
|
||||
dst_engine_id=meta.remote_engine_id,
|
||||
remote_request_id=meta.remote_request_id,
|
||||
dst_engine_id=meta.remote.engine_id,
|
||||
remote_request_id=meta.remote.request_id,
|
||||
local_block_ids=meta.local_physical_block_ids,
|
||||
remote_block_ids=meta.remote_block_ids,
|
||||
remote_block_ids=meta.remote.block_ids,
|
||||
)
|
||||
|
||||
def _read_blocks(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user