[BugFix] Make PD work with Ray (#21072)

Signed-off-by: Kourosh Hakhamaneshi <kourosh@anyscale.com>

parent 6a971ed692
commit 9f414a12ad
@@ -1,13 +1,14 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+import os
+import tempfile
+import textwrap
 import time
-import uuid
-from collections import defaultdict
-from typing import Optional
 from unittest.mock import patch
 
 import pytest
+import ray
 
 from vllm import LLM
 from vllm.config import KVTransferConfig

@@ -15,11 +16,32 @@ from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
     KVConnectorRole, NixlAgentMetadata, NixlConnector, NixlConnectorMetadata,
     NixlConnectorWorker)
 from vllm.forward_context import ForwardContext
+from vllm.mocks.mock_nixl_connector import FakeNixlWrapper
 from vllm.sampling_params import SamplingParams
 
 from .utils import create_request, create_scheduler, create_vllm_config
 
 
+def _make_stub_pkg() -> str:
+    """Return a directory that makes
+    `from nixl._api import nixl_agent` resolve to our FakeNixlWrapper."""
+    td = tempfile.mkdtemp()
+    pkg_root = os.path.join(td, "nixl", "_api")
+    os.makedirs(pkg_root, exist_ok=True)
+
+    stub = textwrap.dedent("""\
+        # Forward the real FakeNixlWrapper that the driver already defined.
+        print("In fake package")
+        from vllm.mocks.mock_nixl_connector import FakeNixlWrapper as nixl_agent
+        """)
+    with open(os.path.join(pkg_root, "__init__.py"), "w") as f:
+        f.write(stub)
+
+    # touch parent package
+    open(os.path.join(td, "nixl", "__init__.py"), "w").close()
+    return td
+
+
 def test_basic_interface():
     """Unit test for basic NixlConnector interface functionality."""
 
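For orientation, a minimal sketch of how a stub directory like this is expected to reach the Ray workers, which resolve `nixl._api` independently of the driver process. This is an illustration only, assuming a local Ray cluster and the `_make_stub_pkg()` helper above; the `_import_check` task is hypothetical and not part of the change:

import ray

# Ship the stub directory to every worker. Ray places the contents of
# working_dir on the workers' import path, so `nixl._api` resolves to the
# fake package there as well.
ray.init(runtime_env={"working_dir": _make_stub_pkg()})


@ray.remote
def _import_check() -> str:
    from nixl._api import nixl_agent  # resolved from the shipped stub
    return nixl_agent.__name__


assert ray.get(_import_check.remote()) == "FakeNixlWrapper"
ray.shutdown()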
@@ -87,77 +109,6 @@ def test_prompt_less_than_block_size():
     assert len(scheduler_output.scheduled_new_reqs) == 1
 
 
-class FakeNixlWrapper:
-    """Mock implementation of NixlWrapper for testing.
-
-    We don't inherit from nixl._api.nixl_agent because nixl may not be
-    installed.
-    """
-
-    AGENT_METADATA = b"fake_agent_metadata"
-    REMOTE_AGENT_NAME = "remote_agent"
-
-    def __init__(self, agent_name: str, *args, **kwargs):
-        self._cycles_before_xfer_done = 0
-        self._check_xfer_state_cycles: defaultdict[int, int] = defaultdict(
-            lambda: 0)
-
-    def get_reg_descs(self, caches_data, memory_type: str) -> list:
-        return [str(uuid.uuid4()) for _ in caches_data]
-
-    def register_memory(self, descs) -> None:
-        pass
-
-    def get_xfer_descs(self, blocks_data, memory_type: str) -> list:
-        return [str(uuid.uuid4()) for _ in blocks_data]
-
-    def prep_xfer_dlist(self, agent_name: str, descs: list) -> int:
-        return uuid.uuid4().int
-
-    def get_agent_metadata(self) -> bytes:
-        return self.AGENT_METADATA
-
-    def add_remote_agent(self, agent_metadata: bytes) -> str:
-        return self.REMOTE_AGENT_NAME
-
-    def get_new_notifs(self) -> dict[str, list[bytes]]:
-        # Used to collect done_sending, which we don't test yet.
-        return {}
-
-    def check_xfer_state(self, handle: int) -> str:
-        if self._check_xfer_state_cycles[
-                handle] >= self._cycles_before_xfer_done:
-            return "DONE"
-        self._check_xfer_state_cycles[handle] += 1
-        return "PROC"
-
-    def release_xfer_handle(self, handle: int) -> None:
-        pass
-
-    def send_notif(self, agent_name: str, notif_msg: bytes) -> None:
-        pass
-
-    def make_prepped_xfer(self,
-                          xfer_type: str,
-                          local_xfer_side_handle: int,
-                          local_block_descs_ids: list[int],
-                          remote_xfer_side_handle: int,
-                          remote_block_descs_ids: list[int],
-                          notif_msg: Optional[bytes] = None) -> int:
-        return uuid.uuid4().int
-
-    def transfer(self, handle: int) -> str:
-        return "PROC"
-
-    ############################################################
-    # Follow are for changing the behavior during testing.
-    ############################################################
-
-    def set_cycles_before_xfer_done(self, cycles: int):
-        """Set the number of cycles before a transfer is considered done."""
-        self._cycles_before_xfer_done = cycles
-
-
 class FakeNixlConnectorWorker(NixlConnectorWorker):
 
     REMOTE_ENGINE_ID = "remote_engine"
@@ -378,10 +329,14 @@ class TestNixlHandshake:
             raise TimeoutError("Took too long to complete async handshake.")
 
 
+# NOTE: resource cleanup in mp backend is a bit finicky, so the order in which
+# we put here is important. First run ray, it will clean up the resources, then
+# the rest of the tests.
+@pytest.mark.parametrize("distributed_executor_backend", ["ray", None])
 @patch(
     "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
     FakeNixlWrapper)
-def test_abort_timeout_on_prefiller(monkeypatch):
+def test_abort_timeout_on_prefiller(monkeypatch, distributed_executor_backend):
     """
     Test lifecycle of an aborted Remote Prefill request hitting the timeout.
     -----> P

@@ -399,11 +354,23 @@ def test_abort_timeout_on_prefiller(monkeypatch):
     timeout = 6
     monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
     monkeypatch.setenv("VLLM_NIXL_ABORT_REQUEST_TIMEOUT", str(timeout))
+
+    # Build runtime_env only if we're using Ray
+    if distributed_executor_backend == "ray":
+        runtime_env = {
+            "working_dir": _make_stub_pkg(),  # ship stub package
+            "env_vars": {
+                "VLLM_NIXL_ABORT_REQUEST_TIMEOUT": str(timeout),
+            },
+        }
+        ray.init(runtime_env=runtime_env)
+
     llm = LLM(
         model=model_name,
         enforce_eager=True,
         gpu_memory_utilization=0.5,
         kv_transfer_config=kv_transfer_config,
+        distributed_executor_backend=distributed_executor_backend,
     )
     remote_prefill_opts = {
         "do_remote_decode": True,
@@ -1,28 +1,12 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-import threading
-from collections import defaultdict
 from concurrent.futures import Future
 from typing import Optional
 
-from vllm.v1.executor.multiproc_executor import MultiprocExecutor
+from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
 from vllm.v1.outputs import ModelRunnerOutput
 
 
-class DummyMultiprocExecutor(MultiprocExecutor):
-
-    def __init__(self, output_rank, world_size):
-        # Manually initialize minimal required fields
-        self.output_rank = output_rank
-        self.world_size = world_size
-        self._send_remaining_count = defaultdict[str,
-                                                 int](lambda: self.world_size)
-        self._recv_remaining_count = defaultdict[str,
-                                                 int](lambda: self.world_size)
-        self.io_thread_pool = None
-        self.shutdown_event = threading.Event()
-
-
 class DummyModelRunnerOutput(ModelRunnerOutput):
 
     def __init__(self,

@@ -33,14 +17,14 @@ class DummyModelRunnerOutput(ModelRunnerOutput):
 
 
 def test_aggregate_workers_output():
-    executor = DummyMultiprocExecutor(output_rank=0, world_size=2)
+    aggregator = KVOutputAggregator(world_size=2)
 
     output1 = DummyModelRunnerOutput(finished_sending={'req1'},
                                      finished_recving={'req2'})
     output2 = DummyModelRunnerOutput(finished_sending=None,
                                      finished_recving=None)
 
-    aggregated = executor._aggregate_workers_output([output1, output2])
+    aggregated = aggregator.aggregate([output1, output2])
 
     assert aggregated is output1
     assert aggregated.finished_sending is None

@@ -51,7 +35,7 @@ def test_aggregate_workers_output():
     output2 = DummyModelRunnerOutput(finished_sending={'req1'},
                                      finished_recving=None)
 
-    aggregated = executor._aggregate_workers_output([output1, output2])
+    aggregated = aggregator.aggregate([output1, output2])
 
     assert aggregated is output1
     assert aggregated.finished_sending == {'req1'}

@@ -62,7 +46,7 @@ def test_aggregate_workers_output():
     output2 = DummyModelRunnerOutput(finished_sending={'req1'},
                                      finished_recving={'req2'})
 
-    aggregated = executor._aggregate_workers_output([output1, output2])
+    aggregated = aggregator.aggregate([output1, output2])
 
     assert aggregated is output1
     assert aggregated.finished_sending is None

@@ -70,12 +54,11 @@ def test_aggregate_workers_output():
 
 
 def test_async_aggregate_workers_output():
-    executor = DummyMultiprocExecutor(output_rank=0, world_size=2)
+    aggregator = KVOutputAggregator(world_size=2)
 
     future1: Future[DummyModelRunnerOutput] = Future()
     future2: Future[DummyModelRunnerOutput] = Future()
-    result_future = executor._async_aggregate_workers_output(
-        [future1, future2])
+    result_future = aggregator.async_aggregate([future1, future2])
 
     output1 = DummyModelRunnerOutput(finished_sending={'req1'},
                                      finished_recving={'req2'})

@@ -92,8 +75,7 @@ def test_async_aggregate_workers_output():
 
     future1 = Future()
     future2 = Future()
-    result_future = executor._async_aggregate_workers_output(
-        [future1, future2])
+    result_future = aggregator.async_aggregate([future1, future2])
 
     output1 = DummyModelRunnerOutput(finished_sending=None,
                                      finished_recving=None)

@@ -110,8 +92,7 @@ def test_async_aggregate_workers_output():
 
     future1 = Future()
     future2 = Future()
-    result_future = executor._async_aggregate_workers_output(
-        [future1, future2])
+    result_future = aggregator.async_aggregate([future1, future2])
 
     output1 = DummyModelRunnerOutput(finished_sending=None,
                                      finished_recving=None)
@@ -3,12 +3,18 @@
 """
 KV cache helper for store.
 """
+from collections import defaultdict
+from collections.abc import Sequence
+from concurrent.futures import CancelledError, Future
+from typing import Optional, cast
+
 import torch
 
 import vllm.envs as envs
 from vllm import _custom_ops as ops
 from vllm.config import VllmConfig, get_current_vllm_config
 from vllm.logger import init_logger
+from vllm.v1.outputs import ModelRunnerOutput
 
 logger = init_logger(__name__)
 

@@ -107,3 +113,87 @@ def get_kv_connector_cache_layout():
                     "layout to HND for better xfer performance.")
         return "HND"
     return "NHD"
+
+
+class KVOutputAggregator:
+    """Utility class to aggregate the output of all workers into a single
+    output corresponding to Rank 0 for scheduler."""
+
+    def __init__(self, world_size: int):
+        # Complete transfer tracker. Used by to track finished requests
+        # [req_id -> n_finished_workers]
+        self._recv_remaining_count = defaultdict[str, int](lambda: world_size)
+        self._send_remaining_count = defaultdict[str, int](lambda: world_size)
+
+    def aggregate(self,
+                  outputs: list[ModelRunnerOutput],
+                  output_rank: int = 0) -> ModelRunnerOutput:
+        # aggregate finished_sending, finished_recving from all workers
+
+        def update_finished_set(req_ids: Optional[set[str]],
+                                remaining_count_dict: dict[str, int],
+                                finished_set: set[str]) -> None:
+            for req_id in req_ids or ():
+                new_count = remaining_count_dict[req_id] - 1
+                if new_count == 0:
+                    finished_set.add(req_id)
+                    del remaining_count_dict[req_id]
+                else:
+                    remaining_count_dict[req_id] = new_count
+
+        finished_sending = set[str]()
+        finished_recving = set[str]()
+        for output in outputs:
+            update_finished_set(output.finished_sending,
+                                self._send_remaining_count, finished_sending)
+            update_finished_set(output.finished_recving,
+                                self._recv_remaining_count, finished_recving)
+
+        # select output of the worker specified by output_rank
+        output = outputs[output_rank]
+
+        # set the aggregated finished_sending / finished_recving
+        # if output.finished_sending/recving is not empty, but the other ranks
+        # still have unfinished send/recv, we want to set the aggregated
+        # finished_sending/recving to None until all ranks have finished
+        # send/recv
+        output.finished_sending = finished_sending if finished_sending else None
+        output.finished_recving = finished_recving if finished_recving else None
+
+        return output
+
+    def async_aggregate(self,
+                        output_futures: Sequence[Future[ModelRunnerOutput]],
+                        output_rank: int = 0) -> Future[ModelRunnerOutput]:
+        """Takes a list of futures and returns a single future which resolves
+        to the respective list of outputs."""
+        result_future: Future[ModelRunnerOutput] = Future()
+
+        outputs: list[Optional[ModelRunnerOutput]] = [None
+                                                      ] * len(output_futures)
+
+        def make_callback(idx):
+
+            def callback(fut):
+                if result_future.done():
+                    return
+
+                try:
+                    outputs[idx] = fut.result()
+                except CancelledError:
+                    result_future.cancel()
+                except Exception as e:
+                    result_future.set_exception(e)
+
+                # this check assumes io_thread_pool uses a single thread
+                if all(outputs):
+                    result_future.set_result(
+                        self.aggregate(cast(list[ModelRunnerOutput], outputs),
+                                       output_rank))
+
+            return callback
+
+        for i, output_future in enumerate(output_futures):
+            output_future.add_done_callback(make_callback(i))
+
+        return result_future
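Read together with the updated unit tests above, the aggregator's contract is that a request id only surfaces in the rank-0 output once every worker has reported it, even when those reports arrive in different scheduler steps. A small sketch of that behaviour, using a hypothetical `FakeOutput` stand-in (the real tests use `DummyModelRunnerOutput`):

from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator


class FakeOutput:
    # Minimal stand-in carrying only the fields the aggregator touches.
    def __init__(self, finished_sending=None, finished_recving=None):
        self.finished_sending = finished_sending
        self.finished_recving = finished_recving


aggregator = KVOutputAggregator(world_size=2)

# Step 1: only worker 0 reports 'req1' as sent -> not surfaced yet.
step1 = aggregator.aggregate([FakeOutput({"req1"}), FakeOutput()])
assert step1.finished_sending is None

# Step 2: worker 1 catches up -> 'req1' is now reported to the scheduler.
step2 = aggregator.aggregate([FakeOutput(), FakeOutput({"req1"})])
assert step2.finished_sending == {"req1"}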
@@ -194,7 +194,7 @@ class KVConnectorBase_V1(ABC):
         """
         Notifies worker-side connector ids of requests that have
         finished generating tokens on the worker.
-        The scheduler process (via the MultiprocExecutor) will use this output
+        The scheduler process (via the Executors) will use this output
         to track which workers are done.
 
         Returns:

new file: vllm/mocks/__init__.py (0 lines)
new file: vllm/mocks/mock_nixl_connector.py (76 lines)
@@ -0,0 +1,76 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import uuid
+from collections import defaultdict
+from typing import Optional
+
+
+class FakeNixlWrapper:
+    """Mock implementation of NixlWrapper for testing.
+
+    We don't inherit from nixl._api.nixl_agent because nixl may not be
+    installed.
+    """
+
+    AGENT_METADATA = b"fake_agent_metadata"
+    REMOTE_AGENT_NAME = "remote_agent"
+
+    def __init__(self, agent_name: str, *args, **kwargs):
+        self._cycles_before_xfer_done = 0
+        self._check_xfer_state_cycles: defaultdict[int, int] = defaultdict(
+            lambda: 0)
+
+    def get_reg_descs(self, caches_data, memory_type: str) -> list:
+        return [str(uuid.uuid4()) for _ in caches_data]
+
+    def register_memory(self, descs) -> None:
+        pass
+
+    def get_xfer_descs(self, blocks_data, memory_type: str) -> list:
+        return [str(uuid.uuid4()) for _ in blocks_data]
+
+    def prep_xfer_dlist(self, agent_name: str, descs: list) -> int:
+        return uuid.uuid4().int
+
+    def get_agent_metadata(self) -> bytes:
+        return self.AGENT_METADATA
+
+    def add_remote_agent(self, agent_metadata: bytes) -> str:
+        return self.REMOTE_AGENT_NAME
+
+    def get_new_notifs(self) -> dict[str, list[bytes]]:
+        # Used to collect done_sending, which we don't test yet.
+        return {}
+
+    def check_xfer_state(self, handle: int) -> str:
+        if self._check_xfer_state_cycles[
+                handle] >= self._cycles_before_xfer_done:
+            return "DONE"
+        self._check_xfer_state_cycles[handle] += 1
+        return "PROC"
+
+    def release_xfer_handle(self, handle: int) -> None:
+        pass
+
+    def send_notif(self, agent_name: str, notif_msg: bytes) -> None:
+        pass
+
+    def make_prepped_xfer(self,
+                          xfer_type: str,
+                          local_xfer_side_handle: int,
+                          local_block_descs_ids: list[int],
+                          remote_xfer_side_handle: int,
+                          remote_block_descs_ids: list[int],
+                          notif_msg: Optional[bytes] = None) -> int:
+        return uuid.uuid4().int
+
+    def transfer(self, handle: int) -> str:
+        return "PROC"
+
+    ############################################################
+    # Follow are for changing the behavior during testing.
+    ############################################################
+
+    def set_cycles_before_xfer_done(self, cycles: int):
+        """Set the number of cycles before a transfer is considered done."""
+        self._cycles_before_xfer_done = cycles
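The mock is meant to be substituted wherever the real nixl agent would be constructed (the tests above patch `NixlWrapper` with it); `set_cycles_before_xfer_done` controls how many polls a transfer takes before reporting completion. A standalone usage sketch, not tied to any particular test in this commit:

from vllm.mocks.mock_nixl_connector import FakeNixlWrapper

wrapper = FakeNixlWrapper("agent0")
wrapper.set_cycles_before_xfer_done(2)

handle = wrapper.make_prepped_xfer("READ",
                                   local_xfer_side_handle=1,
                                   local_block_descs_ids=[0],
                                   remote_xfer_side_handle=2,
                                   remote_block_descs_ids=[0])

# The first two polls report an in-flight transfer, the third reports DONE.
states = [wrapper.check_xfer_state(handle) for _ in range(3)]
assert states == ["PROC", "PROC", "DONE"]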
@@ -1188,9 +1188,15 @@ class IntermediateTensors:
     """For all pipeline stages except the last, we need to return the hidden
     states and residuals to be sent to the next stage. This data structure
     contains the hidden states and residuals for a request.
+
+    Each stage also needs to handle its own finished_sending and
+    finished_recving in case of kv transfer.
     """
 
     tensors: dict[str, torch.Tensor]
+    # [req_ids]
+    finished_sending: Optional[set[str]] = None
+    finished_recving: Optional[set[str]] = None
 
     def __init__(self, tensors):
         # manually define this function, so that
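With these optional fields, a mid-pipeline stage can piggyback connector progress on the `IntermediateTensors` it already ships to the next stage (the model-runner change later in this diff does exactly that). A toy illustration, assuming the class continues to live in `vllm.sequence`:

import torch

from vllm.sequence import IntermediateTensors

# Illustrative only: attach connector progress to the tensors that are
# already being sent downstream between pipeline-parallel stages.
intermediate = IntermediateTensors({"hidden_states": torch.zeros(1, 8)})
intermediate.finished_sending = {"req1"}  # KV blocks for req1 were sent
intermediate.finished_recving = None

assert intermediate["hidden_states"].shape == (1, 8)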
@@ -9,8 +9,7 @@ import threading
 import time
 import traceback
 import weakref
-from collections import defaultdict
-from concurrent.futures import CancelledError, Future, ThreadPoolExecutor
+from concurrent.futures import Future, ThreadPoolExecutor
 from dataclasses import dataclass
 from enum import Enum, auto
 from functools import partial

@@ -27,6 +26,7 @@ from vllm.distributed import (destroy_distributed_environment,
                               destroy_model_parallel)
 from vllm.distributed.device_communicators.shm_broadcast import (Handle,
                                                                  MessageQueue)
+from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
 from vllm.executor.multiproc_worker_utils import (
     _add_prefix, set_multiprocessing_worker_envs)
 from vllm.logger import init_logger

@@ -118,13 +118,8 @@ class MultiprocExecutor(Executor):
 
         self.output_rank = self._get_output_rank()
         self.has_connector = self.vllm_config.kv_transfer_config is not None
-
-        # Complete transfer tracker. Used by to track finished requests
-        # [req_id -> n_finished_workers]
-        self._recv_remaining_count = defaultdict[str,
-                                                 int](lambda: self.world_size)
-        self._send_remaining_count = defaultdict[str,
-                                                 int](lambda: self.world_size)
+        self.kv_output_aggregator = KVOutputAggregator(
+            self.parallel_config.world_size)
 
     def start_worker_monitor(self):
         workers = self.workers

@@ -186,8 +181,9 @@ class MultiprocExecutor(Executor):
 
         # aggregate all workers output to a single output
         if non_block:
-            return self._async_aggregate_workers_output(outputs)
-        return self._aggregate_workers_output(outputs)
+            return self.kv_output_aggregator.async_aggregate(
+                outputs, self.output_rank)
+        return self.kv_output_aggregator.aggregate(outputs, self.output_rank)
 
     def collective_rpc(self,
                        method: Union[str, Callable],

@@ -246,74 +242,6 @@ class MultiprocExecutor(Executor):
         except TimeoutError as e:
             raise TimeoutError(f"RPC call to {method} timed out.") from e
 
-    def _aggregate_workers_output(
-            self, outputs: list[ModelRunnerOutput]) -> ModelRunnerOutput:
-        # aggregate finished_sending, finished_recving from all workers
-
-        def update_finished_set(req_ids: Optional[set[str]],
-                                remaining_count_dict: dict[str, int],
-                                finished_set: set[str]) -> None:
-            for req_id in req_ids or ():
-                new_count = remaining_count_dict[req_id] - 1
-                if new_count == 0:
-                    finished_set.add(req_id)
-                    del remaining_count_dict[req_id]
-                else:
-                    remaining_count_dict[req_id] = new_count
-
-        finished_sending = set[str]()
-        finished_recving = set[str]()
-        for output in outputs:
-            update_finished_set(output.finished_sending,
-                                self._send_remaining_count, finished_sending)
-            update_finished_set(output.finished_recving,
-                                self._recv_remaining_count, finished_recving)
-
-        # select output of the worker specified by output_rank
-        output = outputs[self.output_rank]
-
-        # set the aggregated finished_sending / finished_recving
-        output.finished_sending = finished_sending if finished_sending else None
-        output.finished_recving = finished_recving if finished_recving else None
-
-        return output
-
-    def _async_aggregate_workers_output(
-            self, output_futures: list[Future[ModelRunnerOutput]]
-    ) -> (Future[ModelRunnerOutput]):
-        """Takes a list of futures and returns a single future which resolves
-        to the respective list of outputs."""
-        result_future: Future[ModelRunnerOutput] = Future()
-
-        outputs: list[Optional[ModelRunnerOutput]] = [None
-                                                      ] * len(output_futures)
-
-        def make_callback(idx):
-
-            def callback(fut):
-                if result_future.done():
-                    return
-
-                try:
-                    outputs[idx] = fut.result()
-                except CancelledError:
-                    result_future.cancel()
-                except Exception as e:
-                    result_future.set_exception(e)
-
-                # this check assumes io_thread_pool uses a single thread
-                if all(outputs):
-                    result_future.set_result(
-                        self._aggregate_workers_output(
-                            cast(list[ModelRunnerOutput], outputs)))
-
-            return callback
-
-        for i, output_future in enumerate(output_futures):
-            output_future.add_done_callback(make_callback(i))
-
-        return result_future
-
     @staticmethod
     def _ensure_worker_termination(worker_procs: list[BaseProcess]):
         """Ensure that all worker processes are terminated. Assumes workers have
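With the bookkeeping removed from `MultiprocExecutor`, the executor simply forwards its per-worker futures to `KVOutputAggregator.async_aggregate`, whose returned future resolves only once every worker has responded. A self-contained sketch with plain `concurrent.futures` objects, reusing the hypothetical `FakeOutput` stand-in from the earlier example:

from concurrent.futures import Future

from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator


class FakeOutput:
    def __init__(self, finished_sending=None, finished_recving=None):
        self.finished_sending = finished_sending
        self.finished_recving = finished_recving


aggregator = KVOutputAggregator(world_size=2)
future1, future2 = Future(), Future()
result = aggregator.async_aggregate([future1, future2])

future1.set_result(FakeOutput({"req1"}))
assert not result.done()  # still waiting on worker 1

future2.set_result(FakeOutput({"req1"}))
assert result.done()
assert result.result().finished_sending == {"req1"}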
@@ -2,33 +2,55 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 from concurrent.futures import Future
-from typing import Union
+from typing import Optional, Union
 
+from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
 from vllm.executor.ray_distributed_executor import (  # noqa
     RayDistributedExecutor as RayDistributedExecutorV0)
+from vllm.logger import init_logger
 from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
 from vllm.v1.executor.abstract import Executor
 from vllm.v1.outputs import ModelRunnerOutput
 
+logger = init_logger(__name__)
+
+
 class FutureWrapper(Future):
-    """A wrapper around a Ray output reference to meet the interface
-    of .execute_model().
+    """A wrapper around Ray output reference to meet the interface
+    of .execute_model(): The top level (core busy loop) expects .result() api
+    to block and return a single output.
+
+    If aggregator is provided, the outputs from all workers are aggregated upon
+    the result() call. If not only the first worker's output is returned.
     """
 
-    def __init__(self, ref):
+    def __init__(self, refs, aggregator: Optional[KVOutputAggregator] = None):
         super().__init__()
-        self.ref = ref
+        self.refs = refs
+        self.aggregator = aggregator
 
     def result(self, timeout=None):
         if timeout is not None:
             raise NotImplementedError("timeout is not supported")
-        return self.ref.get()
+
+        if self.aggregator is None:
+            return self.refs[0].get()
+
+        outputs = [ref.get() for ref in self.refs]
+        return self.aggregator.aggregate(outputs, output_rank=0)
 
 
 class RayDistributedExecutor(RayDistributedExecutorV0, Executor):
     """Ray distributed executor using Ray Compiled Graphs."""
 
+    def _init_executor(self) -> None:
+        super()._init_executor()
+
+        # KV connector setup
+        self.has_connector = self.vllm_config.kv_transfer_config is not None
+        self.kv_output_aggregator = KVOutputAggregator(
+            self.parallel_config.world_size)
+
     @property
     def max_concurrent_batches(self) -> int:
         """Ray distributed executor supports pipeline parallelism,

@@ -56,13 +78,24 @@ class RayDistributedExecutor(RayDistributedExecutorV0, Executor):
 
         refs = self.forward_dag.execute(scheduler_output)  # type: ignore
 
-        # When PP is not used, we block here until the result is available.
-        if self.max_concurrent_batches == 1:
-            return refs[0].get()
+        if not self.has_connector:
+            # Get output only from a single worker (output_rank)
+            # When PP is not used, we block here until the result is available.
+            if self.max_concurrent_batches == 1:
+                return refs[0].get()
 
-        # When PP is used, we return a FutureWrapper immediately so that
-        # the scheduler can yield to the next batch.
-        return FutureWrapper(refs[0])
+            # When PP is used, we return a FutureWrapper immediately so that
+            # the scheduler can yield to the next batch.
+            return FutureWrapper(refs)
+
+        # Get output from all workers when connector is present
+        if self.max_concurrent_batches == 1:
+            # Block and get results from all workers
+            outputs = [ref.get() for ref in refs]
+            return self.kv_output_aggregator.aggregate(outputs)
+
+        # Return a future that will aggregate outputs from all workers
+        return FutureWrapper(refs, self.kv_output_aggregator)
 
     def reinitialize_distributed(
             self, reconfig_request: ReconfigureDistributedRequest) -> None:
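The `FutureWrapper` docstring above describes the two modes concretely: without an aggregator, `.result()` just unwraps the first worker's Ray ref; with one, it gathers every ref and defers to `KVOutputAggregator.aggregate`. A minimal sketch of the first mode, with a hypothetical `FakeRef` standing in for a Ray object ref and assuming the V1 module path for the import:

from vllm.v1.executor.ray_distributed_executor import FutureWrapper


class FakeRef:
    # Hypothetical stand-in for a Ray object ref: .get() returns the payload.
    def __init__(self, value):
        self._value = value

    def get(self):
        return self._value


# Without an aggregator, .result() simply returns the first worker's output.
fut = FutureWrapper([FakeRef("rank0-output"), FakeRef("rank1-output")])
assert fut.result() == "rank0-output"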
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+import copy
 import gc
 import time
 from contextlib import contextmanager

@@ -1270,6 +1271,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         hidden_states: torch.Tensor,
         num_scheduled_tokens: int,
         num_scheduled_tokens_np: np.ndarray,
+        finished_sending: Optional[set[str]],
+        finished_recving: Optional[set[str]],
     ) -> ModelRunnerOutput:
         assert self.input_batch.num_reqs ==\
             len(self.input_batch.pooling_params), \

@@ -1304,6 +1307,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             logprobs=None,
             prompt_logprobs_dict={},
             pooler_output=pooler_output,
+            finished_sending=finished_sending,
+            finished_recving=finished_recving,
         )
 
     @torch.inference_mode()

@@ -1314,12 +1319,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
     ) -> Union[ModelRunnerOutput, IntermediateTensors]:
         self._update_states(scheduler_output)
         if not scheduler_output.total_num_scheduled_tokens:
-            if has_kv_transfer_group():
-                with set_forward_context(None, self.vllm_config):
-                    self.maybe_setup_kv_connector(scheduler_output)
+            if not has_kv_transfer_group():
+                # Return empty ModelRunnerOutput if there's no work to do.
+                return EMPTY_MODEL_RUNNER_OUTPUT
 
-            # Return empty ModelRunnerOutput if there's no work to do.
-            return EMPTY_MODEL_RUNNER_OUTPUT
+            return self.kv_connector_no_forward(scheduler_output)
 
         # Prepare the decoder inputs.
         (attn_metadata, attention_cuda_graphs, logits_indices,

@@ -1412,6 +1416,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             )
 
             self.maybe_wait_for_kv_save()
+            finished_sending, finished_recving = (
+                self.get_finished_kv_transfers(scheduler_output))
 
         if self.use_aux_hidden_state_outputs:
             hidden_states, aux_hidden_states = model_output

@@ -1429,6 +1435,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         if not get_pp_group().is_last_rank:
             # For mid-pipeline stages, return the hidden states.
             if not broadcast_pp_output:
+                if finished_sending or finished_recving:
+                    hidden_states.finished_sending = finished_sending
+                    hidden_states.finished_recving = finished_recving
                 return hidden_states
             assert isinstance(hidden_states, IntermediateTensors)
             get_pp_group().send_tensor_dict(hidden_states.tensors,

@@ -1437,7 +1446,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         else:
             if self.input_batch.pooling_params:
                 return self._pool(hidden_states, num_scheduled_tokens,
-                                  num_scheduled_tokens_np)
+                                  num_scheduled_tokens_np, finished_sending,
+                                  finished_recving)
 
         sample_hidden_states = hidden_states[logits_indices]
         logits = self.model.compute_logits(sample_hidden_states, None)

@@ -1587,6 +1597,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             logprobs=logprobs_lists,
             prompt_logprobs_dict=prompt_logprobs_dict,
             pooler_output=[],
+            finished_sending=finished_sending,
+            finished_recving=finished_recving,
             num_nans_in_logits=num_nans_in_logits,
         )
 

@@ -1711,6 +1723,31 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         if has_kv_transfer_group():
             get_kv_transfer_group().wait_for_save()
 
+    @staticmethod
+    def get_finished_kv_transfers(
+        scheduler_output: "SchedulerOutput",
+    ) -> tuple[Optional[set[str]], Optional[set[str]]]:
+        if has_kv_transfer_group():
+            return get_kv_transfer_group().get_finished(
+                scheduler_output.finished_req_ids)
+        return None, None
+
+    def kv_connector_no_forward(
+            self, scheduler_output: "SchedulerOutput") -> ModelRunnerOutput:
+        # KV send/recv even if no work to do.
+        with set_forward_context(None, self.vllm_config):
+            self.maybe_setup_kv_connector(scheduler_output)
+            finished_sending, finished_recving = (
+                self.get_finished_kv_transfers(scheduler_output))
+
+        if not finished_sending and not finished_recving:
+            return EMPTY_MODEL_RUNNER_OUTPUT
+
+        output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
+        output.finished_sending = finished_sending
+        output.finished_recving = finished_recving
+        return output
+
     def propose_ngram_draft_token_ids(
         self,
         sampled_token_ids: list[list[int]],
@@ -15,9 +15,7 @@ from vllm.config import VllmConfig
 from vllm.distributed import (ensure_model_parallel_initialized,
                               init_distributed_environment,
                               set_custom_all_reduce)
-from vllm.distributed.kv_transfer import (ensure_kv_transfer_initialized,
-                                          get_kv_transfer_group,
-                                          has_kv_transfer_group)
+from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
 from vllm.distributed.parallel_state import get_pp_group, get_tp_group
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest

@@ -335,25 +333,17 @@ class Worker(WorkerBase):
             assert isinstance(output, IntermediateTensors)
             get_pp_group().send_tensor_dict(output.tensors,
                                             all_gather_group=get_tp_group())
-            output = EMPTY_MODEL_RUNNER_OUTPUT
+
+            # In case of PP with kv transfer, we need to pass through the
+            # finished_sending and finished_recving buffers.
+            empty_output = EMPTY_MODEL_RUNNER_OUTPUT
+            if output.finished_sending or output.finished_recving:
+                empty_output = copy.copy(empty_output)
+                empty_output.finished_sending = output.finished_sending
+                empty_output.finished_recving = output.finished_recving
+            output = empty_output
 
         assert isinstance(output, ModelRunnerOutput)
-        if has_kv_transfer_group():
-            finished_sending, finished_recving = (
-                get_kv_transfer_group().get_finished(
-                    scheduler_output.finished_req_ids))
-            if finished_sending or finished_recving:
-                if output is EMPTY_MODEL_RUNNER_OUTPUT:
-                    output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
-                output.finished_sending = finished_sending
-                output.finished_recving = finished_recving
-
-            # Clear KVConnector state for this step.
-            get_kv_transfer_group().clear_connector_metadata()
-
-            # with a connector, the scheduler expects output from all workers
-            return output
 
         # return output only from the driver worker
         return output if self.is_driver_worker else None