[NIXL][BUG FIX] Fix a bug for PD with host_buffer after merging 29665 (#30420)
Signed-off-by: Chendi Xue <chendi.xue@intel.com>
Signed-off-by: Mark McLoughlin <markmc@redhat.com>
Co-authored-by: Mark McLoughlin <markmc@redhat.com>
parent 9e33a1a75b
commit ae2e503dda
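
In outline, the diff splits NixlConnectorMetadata.add_new_req into add_new_req_to_recv and add_new_req_to_save, and moves the remote_* fields of ReqMeta into a nested RemoteMeta that is only populated on the receive path. A minimal before/after sketch of the caller-side change (illustrative only: `metadata` is assumed to be a NixlConnectorMetadata instance, `params` a kv_transfer_params dict, and the request id and block ids are placeholders):

# Before: one entry point with mutually exclusive flags.
metadata.add_new_req(
    request_id="req-0",
    local_block_ids=[1, 2, 3],
    kv_transfer_params=params,
    load_remote_cache=True,
    save_to_host=False,
)

# After: the destination bucket is chosen by the method name.
metadata.add_new_req_to_recv(
    request_id="req-0",
    local_block_ids=[1, 2, 3],
    kv_transfer_params=params,
)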
@@ -461,7 +461,7 @@ class TestNixlHandshake:
         metadata = NixlConnectorMetadata()
         if num_xfers > 0:
             num_xfers -= 1
-            metadata.add_new_req(
+            metadata.add_new_req_to_recv(
                 request_id=request_id,
                 local_block_ids=[num_xfers + 1, num_xfers + 2, num_xfers + 3],
                 kv_transfer_params={
@@ -532,7 +532,7 @@ class TestNixlHandshake:
             vllm_config, connector.engine_id
         )
         metadata = NixlConnectorMetadata()
-        metadata.add_new_req(
+        metadata.add_new_req_to_recv(
            request_id="id",
            local_block_ids=[1, 2, 3],
            kv_transfer_params={
@@ -588,7 +588,7 @@ class TestNixlHandshake:
         metadata = NixlConnectorMetadata()
         total_reqs = 5
         for i in range(total_reqs):
-            metadata.add_new_req(
+            metadata.add_new_req_to_recv(
                 request_id=f"id_{i}",
                 local_block_ids=[1, 2, 3],
                 kv_transfer_params={
@@ -752,7 +752,7 @@ def test_kv_connector_stats(dist_init):
     # Create transfer metadata
     request_id = "test_req_for_stats"
     metadata = NixlConnectorMetadata()
-    metadata.add_new_req(
+    metadata.add_new_req_to_recv(
         request_id=request_id,
         local_block_ids=[1, 2, 3],
         kv_transfer_params={
@@ -1515,7 +1515,7 @@ def test_handshake_failure_returns_finished(dist_init):

     request_id = "test_handshake_fail"
     metadata = NixlConnectorMetadata()
-    metadata.add_new_req(
+    metadata.add_new_req_to_recv(
         request_id=request_id,
         local_block_ids=[1, 2, 3],
         kv_transfer_params={
@@ -1565,7 +1565,7 @@ def test_transfer_setup_failure_returns_finished(dist_init):

     request_id = "test_transfer_fail"
     metadata = NixlConnectorMetadata()
-    metadata.add_new_req(
+    metadata.add_new_req_to_recv(
         request_id=request_id,
         local_block_ids=[7, 8, 9],
         kv_transfer_params={
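
The test hunks above all cut off at the opening of kv_transfer_params. Judging from the keys the connector reads on the receive path (see the add_new_req_to_recv hunk below), such a dict looks roughly like the following sketch; the concrete values are invented placeholders:

kv_transfer_params = {
    "remote_block_ids": [1, 2, 3],
    "remote_engine_id": "remote-engine-0",
    "remote_request_id": "remote-req-0",
    "remote_host": "127.0.0.1",
    "remote_port": 5600,
    "tp_size": 1,  # optional; the connector defaults this to 1
}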
@@ -202,17 +202,22 @@ def compute_nixl_compatibility_hash(
     return compat_hash


+@dataclass
+class RemoteMeta:
+    block_ids: list[int]
+    host: str
+    port: int
+    engine_id: str
+    request_id: str
+
+
 @dataclass
 class ReqMeta:
     local_block_ids: list[int]
     # To be used when logical block size does not match the kernel block size
     local_physical_block_ids: list[int]
-    remote_block_ids: list[int]
-    remote_host: str
-    remote_port: int
-    remote_engine_id: str
-    remote_request_id: str
     tp_size: int
+    remote: RemoteMeta | None = None


 class NixlConnectorMetadata(KVConnectorMetadata):
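
A self-contained sketch of the metadata layout after this hunk (field names and types are copied from the diff; the instances and values are invented): a receive-path request carries a RemoteMeta, while a save-to-host request leaves remote as None.

from dataclasses import dataclass


@dataclass
class RemoteMeta:
    block_ids: list[int]
    host: str
    port: int
    engine_id: str
    request_id: str


@dataclass
class ReqMeta:
    local_block_ids: list[int]
    local_physical_block_ids: list[int]
    tp_size: int
    remote: RemoteMeta | None = None


# Receive path: remote placement info is attached.
recv_meta = ReqMeta(
    local_block_ids=[1, 2, 3],
    local_physical_block_ids=[1, 2, 3],
    tp_size=1,
    remote=RemoteMeta(
        block_ids=[7, 8, 9],
        host="10.0.0.1",
        port=5600,
        engine_id="prefill-0",
        request_id="remote-req-0",
    ),
)

# Save-to-host path: there is no remote side, so `remote` stays None.
save_meta = ReqMeta(
    local_block_ids=[4, 5, 6],
    local_physical_block_ids=[4, 5, 6],
    tp_size=1,
)
assert recv_meta.remote is not None and save_meta.remote is None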
@@ -223,31 +228,43 @@ class NixlConnectorMetadata(KVConnectorMetadata):
         self.reqs_in_batch: set[ReqId] = set()
         self.reqs_not_processed: set[ReqId] = set()

-    def add_new_req(
+    def _add_new_req(
+        self,
+        local_block_ids: list[int],
+        kv_transfer_params: dict[str, Any],
+    ) -> ReqMeta:
+        return ReqMeta(
+            local_block_ids=local_block_ids,
+            local_physical_block_ids=local_block_ids,
+            # P workers don't need to receive tp_size from proxy here.
+            tp_size=kv_transfer_params.get("tp_size", 1),
+        )
+
+    def add_new_req_to_save(
         self,
         request_id: ReqId,
         local_block_ids: list[int],
         kv_transfer_params: dict[str, Any],
-        load_remote_cache: bool = True,
-        save_to_host: bool = False,
     ):
-        # save and load are mutually exclusive
-        assert load_remote_cache ^ save_to_host
-        _req = ReqMeta(
-            local_block_ids=local_block_ids,
-            local_physical_block_ids=local_block_ids,
-            remote_block_ids=kv_transfer_params["remote_block_ids"],
-            remote_engine_id=kv_transfer_params["remote_engine_id"],
-            remote_request_id=kv_transfer_params["remote_request_id"],
-            remote_host=kv_transfer_params["remote_host"],
-            remote_port=kv_transfer_params["remote_port"],
-            # P workers don't need to receive tp_size from proxy here.
-            tp_size=kv_transfer_params.get("tp_size", 1),
+        self.reqs_to_save[request_id] = self._add_new_req(
+            local_block_ids, kv_transfer_params
         )
-        if save_to_host:
-            self.reqs_to_save[request_id] = _req
-        if load_remote_cache:
-            self.reqs_to_recv[request_id] = _req
+
+    def add_new_req_to_recv(
+        self,
+        request_id: ReqId,
+        local_block_ids: list[int],
+        kv_transfer_params: dict[str, Any],
+    ):
+        req = self._add_new_req(local_block_ids, kv_transfer_params)
+        req.remote = RemoteMeta(
+            block_ids=kv_transfer_params["remote_block_ids"],
+            engine_id=kv_transfer_params["remote_engine_id"],
+            request_id=kv_transfer_params["remote_request_id"],
+            host=kv_transfer_params["remote_host"],
+            port=kv_transfer_params["remote_port"],
+        )
+        self.reqs_to_recv[request_id] = req


 class NixlConnector(KVConnectorBase_V1):
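
A usage sketch of the two new entry points (meta is assumed to be a NixlConnectorMetadata instance, and params a receive-path dict like the one sketched after the test hunks; ids and block ids are placeholders):

# Decode side: the request lands in meta.reqs_to_recv and gets a RemoteMeta
# built from the remote_* keys of params.
meta.add_new_req_to_recv(
    request_id="req-0",
    local_block_ids=[1, 2, 3],
    kv_transfer_params=params,
)

# Host-buffer save: the request lands in meta.reqs_to_save; only tp_size
# (defaulting to 1) is read, so no remote_* keys are required.
meta.add_new_req_to_save(
    request_id="req-1",
    local_block_ids=[4, 5, 6],
    kv_transfer_params={},
)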
@@ -666,22 +683,18 @@ class NixlConnectorScheduler:
         # Loop through scheduled reqs and convert to ReqMeta.
         for req_id, (req, block_ids) in self._reqs_need_recv.items():
             assert req.kv_transfer_params is not None
-            meta.add_new_req(
+            meta.add_new_req_to_recv(
                 request_id=req_id,
                 local_block_ids=block_ids,
                 kv_transfer_params=req.kv_transfer_params,
-                load_remote_cache=True,
-                save_to_host=False,
             )

         for req_id, (req, block_ids) in self._reqs_need_save.items():
             assert req.kv_transfer_params is not None
-            meta.add_new_req(
+            meta.add_new_req_to_save(
                 request_id=req_id,
                 local_block_ids=block_ids,
                 kv_transfer_params=req.kv_transfer_params,
-                load_remote_cache=False,
-                save_to_host=True,
             )

         meta.reqs_to_send = self._reqs_need_send
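
With the flags gone, the distinction shows up downstream in which bucket a request landed rather than in per-request flags. A rough sketch of how a consumer could walk the two buckets (container and field names are from the diff; connector_metadata and the loop bodies are placeholders):

for req_id, req_meta in connector_metadata.reqs_to_recv.items():
    assert req_meta.remote is not None  # recv entries always carry a RemoteMeta
    ...  # issue NIXL reads against req_meta.remote

for req_id, req_meta in connector_metadata.reqs_to_save.items():
    ...  # stage req_meta.local_block_ids into the host buffer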
@@ -1124,10 +1137,11 @@ class NixlConnectorWorker:
         # Do NIXL handshake in background and add to _ready_requests when done.
         fut = self._handshake_futures.get(remote_engine_id)
         if fut is None:
+            assert meta.remote is not None
             fut = self._handshake_initiation_executor.submit(
                 self._nixl_handshake,
-                meta.remote_host,
-                meta.remote_port,
+                meta.remote.host,
+                meta.remote.port,
                 meta.tp_size,
                 remote_engine_id,
             )
@@ -1774,6 +1788,7 @@ class NixlConnectorWorker:
         # clean up metadata for completed requests
         meta = self._recving_metadata.pop(req_id, None)
         assert meta is not None, f"{req_id} not found in recving_metadata list"
+        assert meta.remote is not None
         if self.use_host_buffer:
             self.sync_recved_kv_to_device(req_id, meta)
         if self.enable_permute_local_kv:
@@ -1781,7 +1796,7 @@

             # post processing for heteroblocksize
             block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id(
-                meta.remote_engine_id
+                meta.remote.engine_id
             )
             if (
                 not self.use_mla
@@ -1916,17 +1931,18 @@ class NixlConnectorWorker:
         meta.local_physical_block_ids = self._logical_to_kernel_block_ids(
             meta.local_block_ids
         )
-        meta.remote_block_ids = self._logical_to_kernel_block_ids(
-            meta.remote_block_ids
+        assert meta.remote is not None
+        meta.remote.block_ids = self._logical_to_kernel_block_ids(
+            meta.remote.block_ids
         )
-        remote_engine_id = meta.remote_engine_id
+        remote_engine_id = meta.remote.engine_id
         logger.debug(
             "start_load_kv for request %s from remote engine %s. "
             "Num local_block_ids: %s. Num remote_block_ids: %s. ",
             req_id,
             remote_engine_id,
             len(meta.local_physical_block_ids),
-            len(meta.remote_block_ids),
+            len(meta.remote.block_ids),
         )
         # always store metadata for failure recovery
         self._recving_metadata[req_id] = meta
@@ -1965,17 +1981,18 @@ class NixlConnectorWorker:
             self._reqs_to_send[req_id] = expiration_time

     def _read_blocks_for_req(self, req_id: str, meta: ReqMeta):
+        assert meta.remote is not None
         logger.debug(
             "Remote agent %s available, calling _read_blocks for req %s",
-            meta.remote_engine_id,
+            meta.remote.engine_id,
             req_id,
         )
         self._read_blocks(
             request_id=req_id,
-            dst_engine_id=meta.remote_engine_id,
-            remote_request_id=meta.remote_request_id,
+            dst_engine_id=meta.remote.engine_id,
+            remote_request_id=meta.remote.request_id,
             local_block_ids=meta.local_physical_block_ids,
-            remote_block_ids=meta.remote_block_ids,
+            remote_block_ids=meta.remote.block_ids,
         )

     def _read_blocks(
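
The worker-side hunks all follow the same pattern: since ReqMeta.remote is now optional, each receive-path consumer first asserts it is present and then reads the nested fields. A self-contained sketch of that narrowing, using stripped-down stand-ins for the real classes (the function below is hypothetical, not the connector's actual method):

from dataclasses import dataclass


@dataclass
class RemoteMeta:  # stand-in with only the fields this sketch touches
    engine_id: str
    request_id: str
    block_ids: list[int]


@dataclass
class ReqMeta:  # stand-in
    local_physical_block_ids: list[int]
    remote: RemoteMeta | None = None


def read_blocks_for_req(req_id: str, meta: ReqMeta) -> None:
    # Mirrors the `assert meta.remote is not None` guards added above: the
    # assert documents the invariant and narrows the Optional for type checkers.
    assert meta.remote is not None
    print(
        f"req {req_id}: {len(meta.local_physical_block_ids)} local blocks <- "
        f"{len(meta.remote.block_ids)} remote blocks from engine "
        f"{meta.remote.engine_id} (remote request {meta.remote.request_id})"
    )


read_blocks_for_req(
    "req-0",
    ReqMeta(
        local_physical_block_ids=[1, 2, 3],
        remote=RemoteMeta(engine_id="prefill-0", request_id="r-0", block_ids=[7, 8, 9]),
    ),
)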