mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-14 21:55:49 +08:00
[P/D] Move FakeNixlWrapper to test dir (#21328)
Signed-off-by: Rui Qiao <ruisearch42@gmail.com>
This commit is contained in:
parent
d9f9a3fd96
commit
1e9ea8e69d
@ -1,10 +1,15 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import contextlib
|
||||||
|
import inspect
|
||||||
import os
|
import os
|
||||||
import tempfile
|
import tempfile
|
||||||
import textwrap
|
import textwrap
|
||||||
import time
|
import time
|
||||||
|
import uuid
|
||||||
|
from collections import defaultdict
|
||||||
|
from typing import Optional
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@ -16,30 +21,118 @@ from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
|
|||||||
KVConnectorRole, NixlAgentMetadata, NixlConnector, NixlConnectorMetadata,
|
KVConnectorRole, NixlAgentMetadata, NixlConnector, NixlConnectorMetadata,
|
||||||
NixlConnectorWorker)
|
NixlConnectorWorker)
|
||||||
from vllm.forward_context import ForwardContext
|
from vllm.forward_context import ForwardContext
|
||||||
from vllm.mocks.mock_nixl_connector import FakeNixlWrapper
|
|
||||||
from vllm.sampling_params import SamplingParams
|
from vllm.sampling_params import SamplingParams
|
||||||
|
|
||||||
from .utils import create_request, create_scheduler, create_vllm_config
|
from .utils import create_request, create_scheduler, create_vllm_config
|
||||||
|
|
||||||
|
|
||||||
def _make_stub_pkg() -> str:
|
class FakeNixlWrapper:
|
||||||
"""Return a directory that makes
|
"""Mock implementation of NixlWrapper for testing.
|
||||||
`from nixl._api import nixl_agent` resolve to our FakeNixlWrapper."""
|
|
||||||
td = tempfile.mkdtemp()
|
|
||||||
pkg_root = os.path.join(td, "nixl", "_api")
|
|
||||||
os.makedirs(pkg_root, exist_ok=True)
|
|
||||||
|
|
||||||
stub = textwrap.dedent("""\
|
We don't inherit from nixl._api.nixl_agent because nixl may not be
|
||||||
# Forward the real FakeNixlWrapper that the driver already defined.
|
installed.
|
||||||
print("In fake package")
|
|
||||||
from vllm.mocks.mock_nixl_connector import FakeNixlWrapper as nixl_agent
|
Note: The complete source of this class is also used in the
|
||||||
""")
|
`_make_fake_nixl_pkg` function to create a fake nixl package
|
||||||
with open(os.path.join(pkg_root, "__init__.py"), "w") as f:
|
for Ray workers.
|
||||||
f.write(stub)
|
"""
|
||||||
|
|
||||||
# touch parent package
|
AGENT_METADATA = b"fake_agent_metadata"
|
||||||
open(os.path.join(td, "nixl", "__init__.py"), "w").close()
|
REMOTE_AGENT_NAME = "remote_agent"
|
||||||
return td
|
|
||||||
|
def __init__(self, agent_name: str, *args, **kwargs):
|
||||||
|
self._cycles_before_xfer_done = 0
|
||||||
|
self._check_xfer_state_cycles: defaultdict[int, int] = defaultdict(
|
||||||
|
lambda: 0)
|
||||||
|
|
||||||
|
def get_reg_descs(self, caches_data, memory_type: str) -> list:
|
||||||
|
return [str(uuid.uuid4()) for _ in caches_data]
|
||||||
|
|
||||||
|
def register_memory(self, descs) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def get_xfer_descs(self, blocks_data, memory_type: str) -> list:
|
||||||
|
return [str(uuid.uuid4()) for _ in blocks_data]
|
||||||
|
|
||||||
|
def prep_xfer_dlist(self, agent_name: str, descs: list) -> int:
|
||||||
|
return uuid.uuid4().int
|
||||||
|
|
||||||
|
def get_agent_metadata(self) -> bytes:
|
||||||
|
return self.AGENT_METADATA
|
||||||
|
|
||||||
|
def add_remote_agent(self, agent_metadata: bytes) -> str:
|
||||||
|
return self.REMOTE_AGENT_NAME
|
||||||
|
|
||||||
|
def get_new_notifs(self) -> dict[str, list[bytes]]:
|
||||||
|
# Used to collect done_sending, which we don't test yet.
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def check_xfer_state(self, handle: int) -> str:
|
||||||
|
if self._check_xfer_state_cycles[
|
||||||
|
handle] >= self._cycles_before_xfer_done:
|
||||||
|
return "DONE"
|
||||||
|
self._check_xfer_state_cycles[handle] += 1
|
||||||
|
return "PROC"
|
||||||
|
|
||||||
|
def release_xfer_handle(self, handle: int) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def send_notif(self, agent_name: str, notif_msg: bytes) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def make_prepped_xfer(self,
|
||||||
|
xfer_type: str,
|
||||||
|
local_xfer_side_handle: int,
|
||||||
|
local_block_descs_ids: list[int],
|
||||||
|
remote_xfer_side_handle: int,
|
||||||
|
remote_block_descs_ids: list[int],
|
||||||
|
notif_msg: Optional[bytes] = None) -> int:
|
||||||
|
return uuid.uuid4().int
|
||||||
|
|
||||||
|
def transfer(self, handle: int) -> str:
|
||||||
|
return "PROC"
|
||||||
|
|
||||||
|
############################################################
|
||||||
|
# Follow are for changing the behavior during testing.
|
||||||
|
############################################################
|
||||||
|
|
||||||
|
def set_cycles_before_xfer_done(self, cycles: int):
|
||||||
|
"""Set the number of cycles before a transfer is considered done."""
|
||||||
|
self._cycles_before_xfer_done = cycles
|
||||||
|
|
||||||
|
|
||||||
|
@contextlib.contextmanager
|
||||||
|
def _make_fake_nixl_pkg():
|
||||||
|
"""Context manager that creates a temporary package making
|
||||||
|
`from nixl._api import nixl_agent` resolve to our FakeNixlWrapper.
|
||||||
|
|
||||||
|
Automatically cleans up the temporary directory when done.
|
||||||
|
"""
|
||||||
|
with tempfile.TemporaryDirectory() as td:
|
||||||
|
pkg_root = os.path.join(td, "nixl", "_api")
|
||||||
|
os.makedirs(pkg_root, exist_ok=True)
|
||||||
|
|
||||||
|
# Get the source code of FakeNixlWrapper class and dedent it
|
||||||
|
fake_nixl_source = inspect.getsource(FakeNixlWrapper)
|
||||||
|
fake_nixl_source = textwrap.dedent(fake_nixl_source)
|
||||||
|
|
||||||
|
stub = f"""\
|
||||||
|
# Copy of FakeNixlWrapper implementation for Ray workers
|
||||||
|
import uuid
|
||||||
|
from collections import defaultdict
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
{fake_nixl_source}
|
||||||
|
|
||||||
|
# Export as nixl_agent
|
||||||
|
nixl_agent = FakeNixlWrapper
|
||||||
|
"""
|
||||||
|
with open(os.path.join(pkg_root, "__init__.py"), "w") as f:
|
||||||
|
f.write(stub)
|
||||||
|
|
||||||
|
# touch parent package
|
||||||
|
open(os.path.join(td, "nixl", "__init__.py"), "w").close()
|
||||||
|
yield td
|
||||||
|
|
||||||
|
|
||||||
def test_basic_interface():
|
def test_basic_interface():
|
||||||
@ -351,27 +444,37 @@ def test_abort_timeout_on_prefiller(monkeypatch, distributed_executor_backend):
|
|||||||
kv_connector="NixlConnector",
|
kv_connector="NixlConnector",
|
||||||
kv_role="kv_both",
|
kv_role="kv_both",
|
||||||
)
|
)
|
||||||
|
llm_kwargs = {
|
||||||
|
"model": model_name,
|
||||||
|
"enforce_eager": True,
|
||||||
|
"gpu_memory_utilization": 0.5,
|
||||||
|
"kv_transfer_config": kv_transfer_config,
|
||||||
|
"distributed_executor_backend": distributed_executor_backend,
|
||||||
|
}
|
||||||
|
|
||||||
timeout = 6
|
timeout = 6
|
||||||
monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
|
monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
|
||||||
monkeypatch.setenv("VLLM_NIXL_ABORT_REQUEST_TIMEOUT", str(timeout))
|
monkeypatch.setenv("VLLM_NIXL_ABORT_REQUEST_TIMEOUT", str(timeout))
|
||||||
|
|
||||||
# Build runtime_env only if we’re using Ray
|
# Build runtime_env only if we're using Ray
|
||||||
if distributed_executor_backend == "ray":
|
if distributed_executor_backend == "ray":
|
||||||
runtime_env = {
|
with _make_fake_nixl_pkg() as working_dir:
|
||||||
"working_dir": _make_stub_pkg(), # ship stub package
|
runtime_env = {
|
||||||
"env_vars": {
|
"working_dir": working_dir, # ship fake nixl package
|
||||||
"VLLM_NIXL_ABORT_REQUEST_TIMEOUT": str(timeout),
|
"env_vars": {
|
||||||
},
|
"VLLM_NIXL_ABORT_REQUEST_TIMEOUT": str(timeout),
|
||||||
}
|
},
|
||||||
ray.init(runtime_env=runtime_env)
|
}
|
||||||
|
ray.init(runtime_env=runtime_env)
|
||||||
|
|
||||||
llm = LLM(
|
_run_abort_timeout_test(llm_kwargs, timeout)
|
||||||
model=model_name,
|
else:
|
||||||
enforce_eager=True,
|
_run_abort_timeout_test(llm_kwargs, timeout)
|
||||||
gpu_memory_utilization=0.5,
|
|
||||||
kv_transfer_config=kv_transfer_config,
|
|
||||||
distributed_executor_backend=distributed_executor_backend,
|
def _run_abort_timeout_test(llm_kwargs: dict, timeout: int):
|
||||||
)
|
"""Helper function to run the abort timeout test logic."""
|
||||||
|
llm = LLM(**llm_kwargs)
|
||||||
remote_prefill_opts = {
|
remote_prefill_opts = {
|
||||||
"do_remote_decode": True,
|
"do_remote_decode": True,
|
||||||
"do_remote_prefill": False,
|
"do_remote_prefill": False,
|
||||||
|
|||||||
@ -120,8 +120,8 @@ class KVOutputAggregator:
|
|||||||
output corresponding to Rank 0 for scheduler."""
|
output corresponding to Rank 0 for scheduler."""
|
||||||
|
|
||||||
def __init__(self, world_size: int):
|
def __init__(self, world_size: int):
|
||||||
# Complete transfer tracker. Used by to track finished requests
|
# Complete transfer tracker. Used to track finished requests
|
||||||
# [req_id -> n_finished_workers]
|
# [req_id -> n_remaining_workers]
|
||||||
self._recv_remaining_count = defaultdict[str, int](lambda: world_size)
|
self._recv_remaining_count = defaultdict[str, int](lambda: world_size)
|
||||||
self._send_remaining_count = defaultdict[str, int](lambda: world_size)
|
self._send_remaining_count = defaultdict[str, int](lambda: world_size)
|
||||||
|
|
||||||
@ -134,12 +134,10 @@ class KVOutputAggregator:
|
|||||||
remaining_count_dict: dict[str, int],
|
remaining_count_dict: dict[str, int],
|
||||||
finished_set: set[str]) -> None:
|
finished_set: set[str]) -> None:
|
||||||
for req_id in req_ids or ():
|
for req_id in req_ids or ():
|
||||||
new_count = remaining_count_dict[req_id] - 1
|
remaining_count_dict[req_id] -= 1
|
||||||
if new_count == 0:
|
if remaining_count_dict[req_id] == 0:
|
||||||
finished_set.add(req_id)
|
finished_set.add(req_id)
|
||||||
del remaining_count_dict[req_id]
|
del remaining_count_dict[req_id]
|
||||||
else:
|
|
||||||
remaining_count_dict[req_id] = new_count
|
|
||||||
|
|
||||||
finished_sending = set[str]()
|
finished_sending = set[str]()
|
||||||
finished_recving = set[str]()
|
finished_recving = set[str]()
|
||||||
|
|||||||
@ -1,76 +0,0 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
||||||
import uuid
|
|
||||||
from collections import defaultdict
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
|
|
||||||
class FakeNixlWrapper:
|
|
||||||
"""Mock implementation of NixlWrapper for testing.
|
|
||||||
|
|
||||||
We don't inherit from nixl._api.nixl_agent because nixl may not be
|
|
||||||
installed.
|
|
||||||
"""
|
|
||||||
|
|
||||||
AGENT_METADATA = b"fake_agent_metadata"
|
|
||||||
REMOTE_AGENT_NAME = "remote_agent"
|
|
||||||
|
|
||||||
def __init__(self, agent_name: str, *args, **kwargs):
|
|
||||||
self._cycles_before_xfer_done = 0
|
|
||||||
self._check_xfer_state_cycles: defaultdict[int, int] = defaultdict(
|
|
||||||
lambda: 0)
|
|
||||||
|
|
||||||
def get_reg_descs(self, caches_data, memory_type: str) -> list:
|
|
||||||
return [str(uuid.uuid4()) for _ in caches_data]
|
|
||||||
|
|
||||||
def register_memory(self, descs) -> None:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def get_xfer_descs(self, blocks_data, memory_type: str) -> list:
|
|
||||||
return [str(uuid.uuid4()) for _ in blocks_data]
|
|
||||||
|
|
||||||
def prep_xfer_dlist(self, agent_name: str, descs: list) -> int:
|
|
||||||
return uuid.uuid4().int
|
|
||||||
|
|
||||||
def get_agent_metadata(self) -> bytes:
|
|
||||||
return self.AGENT_METADATA
|
|
||||||
|
|
||||||
def add_remote_agent(self, agent_metadata: bytes) -> str:
|
|
||||||
return self.REMOTE_AGENT_NAME
|
|
||||||
|
|
||||||
def get_new_notifs(self) -> dict[str, list[bytes]]:
|
|
||||||
# Used to collect done_sending, which we don't test yet.
|
|
||||||
return {}
|
|
||||||
|
|
||||||
def check_xfer_state(self, handle: int) -> str:
|
|
||||||
if self._check_xfer_state_cycles[
|
|
||||||
handle] >= self._cycles_before_xfer_done:
|
|
||||||
return "DONE"
|
|
||||||
self._check_xfer_state_cycles[handle] += 1
|
|
||||||
return "PROC"
|
|
||||||
|
|
||||||
def release_xfer_handle(self, handle: int) -> None:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def send_notif(self, agent_name: str, notif_msg: bytes) -> None:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def make_prepped_xfer(self,
|
|
||||||
xfer_type: str,
|
|
||||||
local_xfer_side_handle: int,
|
|
||||||
local_block_descs_ids: list[int],
|
|
||||||
remote_xfer_side_handle: int,
|
|
||||||
remote_block_descs_ids: list[int],
|
|
||||||
notif_msg: Optional[bytes] = None) -> int:
|
|
||||||
return uuid.uuid4().int
|
|
||||||
|
|
||||||
def transfer(self, handle: int) -> str:
|
|
||||||
return "PROC"
|
|
||||||
|
|
||||||
############################################################
|
|
||||||
# Follow are for changing the behavior during testing.
|
|
||||||
############################################################
|
|
||||||
|
|
||||||
def set_cycles_before_xfer_done(self, cycles: int):
|
|
||||||
"""Set the number of cycles before a transfer is considered done."""
|
|
||||||
self._cycles_before_xfer_done = cycles
|
|
||||||
Loading…
x
Reference in New Issue
Block a user