[BugFix] Make PD work with Ray (#21072)

Signed-off-by: Kourosh Hakhamaneshi <kourosh@anyscale.com>
kourosh hakhamaneshi 2025-07-19 08:46:50 -07:00 committed by GitHub
parent 6a971ed692
commit 9f414a12ad
11 changed files with 330 additions and 222 deletions

View File

@ -1,13 +1,14 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
import tempfile
import textwrap
import time
import uuid
from collections import defaultdict
from typing import Optional
from unittest.mock import patch
import pytest
import ray
from vllm import LLM
from vllm.config import KVTransferConfig
@ -15,11 +16,32 @@ from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
KVConnectorRole, NixlAgentMetadata, NixlConnector, NixlConnectorMetadata,
NixlConnectorWorker)
from vllm.forward_context import ForwardContext
from vllm.mocks.mock_nixl_connector import FakeNixlWrapper
from vllm.sampling_params import SamplingParams
from .utils import create_request, create_scheduler, create_vllm_config
def _make_stub_pkg() -> str:
"""Return a directory that makes
`from nixl._api import nixl_agent` resolve to our FakeNixlWrapper."""
td = tempfile.mkdtemp()
pkg_root = os.path.join(td, "nixl", "_api")
os.makedirs(pkg_root, exist_ok=True)
stub = textwrap.dedent("""\
# Forward the real FakeNixlWrapper that the driver already defined.
print("In fake package")
from vllm.mocks.mock_nixl_connector import FakeNixlWrapper as nixl_agent
""")
with open(os.path.join(pkg_root, "__init__.py"), "w") as f:
f.write(stub)
# touch parent package
open(os.path.join(td, "nixl", "__init__.py"), "w").close()
return td
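
The stub package above exists because Ray workers run in separate processes where the test's `@patch(..., FakeNixlWrapper)` does not apply; shipping the directory via `runtime_env={"working_dir": ...}` lets `from nixl._api import nixl_agent` resolve to the fake wrapper inside each worker. Below is a minimal sketch of that resolution, assuming the stub directory is simply prepended to `sys.path` (which is effectively what Ray's working_dir upload does) and that vllm is importable in the worker:

import sys

# Simulate what a Ray worker sees once the stub directory is shipped:
# the directory returned by _make_stub_pkg() ends up on sys.path.
stub_dir = _make_stub_pkg()
sys.path.insert(0, stub_dir)

# This now resolves to FakeNixlWrapper instead of the real nixl agent.
from nixl._api import nixl_agent

assert nixl_agent.__name__ == "FakeNixlWrapper"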
def test_basic_interface():
"""Unit test for basic NixlConnector interface functionality."""
@ -87,77 +109,6 @@ def test_prompt_less_than_block_size():
assert len(scheduler_output.scheduled_new_reqs) == 1
class FakeNixlWrapper:
"""Mock implementation of NixlWrapper for testing.
We don't inherit from nixl._api.nixl_agent because nixl may not be
installed.
"""
AGENT_METADATA = b"fake_agent_metadata"
REMOTE_AGENT_NAME = "remote_agent"
def __init__(self, agent_name: str, *args, **kwargs):
self._cycles_before_xfer_done = 0
self._check_xfer_state_cycles: defaultdict[int, int] = defaultdict(
lambda: 0)
def get_reg_descs(self, caches_data, memory_type: str) -> list:
return [str(uuid.uuid4()) for _ in caches_data]
def register_memory(self, descs) -> None:
pass
def get_xfer_descs(self, blocks_data, memory_type: str) -> list:
return [str(uuid.uuid4()) for _ in blocks_data]
def prep_xfer_dlist(self, agent_name: str, descs: list) -> int:
return uuid.uuid4().int
def get_agent_metadata(self) -> bytes:
return self.AGENT_METADATA
def add_remote_agent(self, agent_metadata: bytes) -> str:
return self.REMOTE_AGENT_NAME
def get_new_notifs(self) -> dict[str, list[bytes]]:
# Used to collect done_sending, which we don't test yet.
return {}
def check_xfer_state(self, handle: int) -> str:
if self._check_xfer_state_cycles[
handle] >= self._cycles_before_xfer_done:
return "DONE"
self._check_xfer_state_cycles[handle] += 1
return "PROC"
def release_xfer_handle(self, handle: int) -> None:
pass
def send_notif(self, agent_name: str, notif_msg: bytes) -> None:
pass
def make_prepped_xfer(self,
xfer_type: str,
local_xfer_side_handle: int,
local_block_descs_ids: list[int],
remote_xfer_side_handle: int,
remote_block_descs_ids: list[int],
notif_msg: Optional[bytes] = None) -> int:
return uuid.uuid4().int
def transfer(self, handle: int) -> str:
return "PROC"
############################################################
# Follow are for changing the behavior during testing.
############################################################
def set_cycles_before_xfer_done(self, cycles: int):
"""Set the number of cycles before a transfer is considered done."""
self._cycles_before_xfer_done = cycles
class FakeNixlConnectorWorker(NixlConnectorWorker):
REMOTE_ENGINE_ID = "remote_engine"
@ -378,10 +329,14 @@ class TestNixlHandshake:
raise TimeoutError("Took too long to complete async handshake.")
# NOTE: resource cleanup in the mp backend is a bit finicky, so the ordering
# here matters: run the Ray case first so it cleans up its resources, then run
# the rest of the tests.
@pytest.mark.parametrize("distributed_executor_backend", ["ray", None])
@patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
FakeNixlWrapper)
def test_abort_timeout_on_prefiller(monkeypatch):
def test_abort_timeout_on_prefiller(monkeypatch, distributed_executor_backend):
"""
Test lifecycle of an aborted Remote Prefill request hitting the timeout.
-----> P
@ -399,11 +354,23 @@ def test_abort_timeout_on_prefiller(monkeypatch):
timeout = 6
monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
monkeypatch.setenv("VLLM_NIXL_ABORT_REQUEST_TIMEOUT", str(timeout))
# Build runtime_env only if we're using Ray
if distributed_executor_backend == "ray":
runtime_env = {
"working_dir": _make_stub_pkg(), # ship stub package
"env_vars": {
"VLLM_NIXL_ABORT_REQUEST_TIMEOUT": str(timeout),
},
}
ray.init(runtime_env=runtime_env)
llm = LLM(
model=model_name,
enforce_eager=True,
gpu_memory_utilization=0.5,
kv_transfer_config=kv_transfer_config,
distributed_executor_backend=distributed_executor_backend,
)
remote_prefill_opts = {
"do_remote_decode": True,

View File

@ -1,28 +1,12 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import threading
from collections import defaultdict
from concurrent.futures import Future
from typing import Optional
from vllm.v1.executor.multiproc_executor import MultiprocExecutor
from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
from vllm.v1.outputs import ModelRunnerOutput
class DummyMultiprocExecutor(MultiprocExecutor):
def __init__(self, output_rank, world_size):
# Manually initialize minimal required fields
self.output_rank = output_rank
self.world_size = world_size
self._send_remaining_count = defaultdict[str,
int](lambda: self.world_size)
self._recv_remaining_count = defaultdict[str,
int](lambda: self.world_size)
self.io_thread_pool = None
self.shutdown_event = threading.Event()
class DummyModelRunnerOutput(ModelRunnerOutput):
def __init__(self,
@ -33,14 +17,14 @@ class DummyModelRunnerOutput(ModelRunnerOutput):
def test_aggregate_workers_output():
executor = DummyMultiprocExecutor(output_rank=0, world_size=2)
aggregator = KVOutputAggregator(world_size=2)
output1 = DummyModelRunnerOutput(finished_sending={'req1'},
finished_recving={'req2'})
output2 = DummyModelRunnerOutput(finished_sending=None,
finished_recving=None)
aggregated = executor._aggregate_workers_output([output1, output2])
aggregated = aggregator.aggregate([output1, output2])
assert aggregated is output1
assert aggregated.finished_sending is None
@ -51,7 +35,7 @@ def test_aggregate_workers_output():
output2 = DummyModelRunnerOutput(finished_sending={'req1'},
finished_recving=None)
aggregated = executor._aggregate_workers_output([output1, output2])
aggregated = aggregator.aggregate([output1, output2])
assert aggregated is output1
assert aggregated.finished_sending == {'req1'}
@ -62,7 +46,7 @@ def test_aggregate_workers_output():
output2 = DummyModelRunnerOutput(finished_sending={'req1'},
finished_recving={'req2'})
aggregated = executor._aggregate_workers_output([output1, output2])
aggregated = aggregator.aggregate([output1, output2])
assert aggregated is output1
assert aggregated.finished_sending is None
@ -70,12 +54,11 @@ def test_aggregate_workers_output():
def test_async_aggregate_workers_output():
executor = DummyMultiprocExecutor(output_rank=0, world_size=2)
aggregator = KVOutputAggregator(world_size=2)
future1: Future[DummyModelRunnerOutput] = Future()
future2: Future[DummyModelRunnerOutput] = Future()
result_future = executor._async_aggregate_workers_output(
[future1, future2])
result_future = aggregator.async_aggregate([future1, future2])
output1 = DummyModelRunnerOutput(finished_sending={'req1'},
finished_recving={'req2'})
@ -92,8 +75,7 @@ def test_async_aggregate_workers_output():
future1 = Future()
future2 = Future()
result_future = executor._async_aggregate_workers_output(
[future1, future2])
result_future = aggregator.async_aggregate([future1, future2])
output1 = DummyModelRunnerOutput(finished_sending=None,
finished_recving=None)
@ -110,8 +92,7 @@ def test_async_aggregate_workers_output():
future1 = Future()
future2 = Future()
result_future = executor._async_aggregate_workers_output(
[future1, future2])
result_future = aggregator.async_aggregate([future1, future2])
output1 = DummyModelRunnerOutput(finished_sending=None,
finished_recving=None)

View File

@ -3,12 +3,18 @@
"""
KV cache helper for store.
"""
from collections import defaultdict
from collections.abc import Sequence
from concurrent.futures import CancelledError, Future
from typing import Optional, cast
import torch
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.logger import init_logger
from vllm.v1.outputs import ModelRunnerOutput
logger = init_logger(__name__)
@ -107,3 +113,87 @@ def get_kv_connector_cache_layout():
"layout to HND for better xfer performance.")
return "HND"
return "NHD"
class KVOutputAggregator:
"""Utility class to aggregate the output of all workers into a single
output corresponding to the chosen output rank (rank 0 by default) for the scheduler."""
def __init__(self, world_size: int):
# Completed-transfer trackers, used to track finished requests.
# [req_id -> number of workers that have not yet finished that request]
self._recv_remaining_count = defaultdict[str, int](lambda: world_size)
self._send_remaining_count = defaultdict[str, int](lambda: world_size)
def aggregate(self,
outputs: list[ModelRunnerOutput],
output_rank: int = 0) -> ModelRunnerOutput:
# aggregate finished_sending, finished_recving from all workers
def update_finished_set(req_ids: Optional[set[str]],
remaining_count_dict: dict[str, int],
finished_set: set[str]) -> None:
for req_id in req_ids or ():
new_count = remaining_count_dict[req_id] - 1
if new_count == 0:
finished_set.add(req_id)
del remaining_count_dict[req_id]
else:
remaining_count_dict[req_id] = new_count
finished_sending = set[str]()
finished_recving = set[str]()
for output in outputs:
update_finished_set(output.finished_sending,
self._send_remaining_count, finished_sending)
update_finished_set(output.finished_recving,
self._recv_remaining_count, finished_recving)
# select output of the worker specified by output_rank
output = outputs[output_rank]
# set the aggregated finished_sending / finished_recving
# NOTE: even if this rank's output.finished_sending / finished_recving is
# non-empty, these fields stay None until every rank has reported the
# corresponding send/recv as finished
output.finished_sending = finished_sending if finished_sending else None
output.finished_recving = finished_recving if finished_recving else None
return output
def async_aggregate(self,
output_futures: Sequence[Future[ModelRunnerOutput]],
output_rank: int = 0) -> Future[ModelRunnerOutput]:
"""Takes a list of futures and returns a single future which resolves
to the aggregated output once all of them have completed."""
result_future: Future[ModelRunnerOutput] = Future()
outputs: list[Optional[ModelRunnerOutput]] = [None
] * len(output_futures)
def make_callback(idx):
def callback(fut):
if result_future.done():
return
try:
outputs[idx] = fut.result()
except CancelledError:
result_future.cancel()
except Exception as e:
result_future.set_exception(e)
# this check assumes the done callbacks run serially (e.g. on a
# single-threaded io_thread_pool)
if all(outputs):
result_future.set_result(
self.aggregate(cast(list[ModelRunnerOutput], outputs),
output_rank))
return callback
for i, output_future in enumerate(output_futures):
output_future.add_done_callback(make_callback(i))
return result_future
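
For illustration, here is a minimal sketch of the counting semantics (not part of this change; the `fake_output` helper is a hypothetical stand-in for ModelRunnerOutput carrying only the fields the aggregator touches). With world_size=2, a request only appears in the aggregated finished_sending once both ranks have reported it:

from types import SimpleNamespace

from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator

def fake_output(sending=None, recving=None):
    # Stand-in for ModelRunnerOutput with only the fields used here.
    return SimpleNamespace(finished_sending=sending, finished_recving=recving)

aggregator = KVOutputAggregator(world_size=2)

# Rank 0 reports req1 as sent, rank 1 has not yet: nothing is aggregated.
out = aggregator.aggregate([fake_output({"req1"}), fake_output()])
assert out.finished_sending is None

# Rank 1 reports req1 on a later step: req1 is now finished on all ranks.
out = aggregator.aggregate([fake_output(), fake_output({"req1"})])
assert out.finished_sending == {"req1"}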

View File

@ -194,7 +194,7 @@ class KVConnectorBase_V1(ABC):
"""
Notifies worker-side connector ids of requests that have
finished generating tokens on the worker.
The scheduler process (via the MultiprocExecutor) will use this output
The scheduler process (via the Executors) will use this output
to track which workers are done.
Returns:

vllm/mocks/__init__.py (new, empty file)
View File

View File

@ -0,0 +1,76 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import uuid
from collections import defaultdict
from typing import Optional
class FakeNixlWrapper:
"""Mock implementation of NixlWrapper for testing.
We don't inherit from nixl._api.nixl_agent because nixl may not be
installed.
"""
AGENT_METADATA = b"fake_agent_metadata"
REMOTE_AGENT_NAME = "remote_agent"
def __init__(self, agent_name: str, *args, **kwargs):
self._cycles_before_xfer_done = 0
self._check_xfer_state_cycles: defaultdict[int, int] = defaultdict(
lambda: 0)
def get_reg_descs(self, caches_data, memory_type: str) -> list:
return [str(uuid.uuid4()) for _ in caches_data]
def register_memory(self, descs) -> None:
pass
def get_xfer_descs(self, blocks_data, memory_type: str) -> list:
return [str(uuid.uuid4()) for _ in blocks_data]
def prep_xfer_dlist(self, agent_name: str, descs: list) -> int:
return uuid.uuid4().int
def get_agent_metadata(self) -> bytes:
return self.AGENT_METADATA
def add_remote_agent(self, agent_metadata: bytes) -> str:
return self.REMOTE_AGENT_NAME
def get_new_notifs(self) -> dict[str, list[bytes]]:
# Used to collect done_sending, which we don't test yet.
return {}
def check_xfer_state(self, handle: int) -> str:
if self._check_xfer_state_cycles[
handle] >= self._cycles_before_xfer_done:
return "DONE"
self._check_xfer_state_cycles[handle] += 1
return "PROC"
def release_xfer_handle(self, handle: int) -> None:
pass
def send_notif(self, agent_name: str, notif_msg: bytes) -> None:
pass
def make_prepped_xfer(self,
xfer_type: str,
local_xfer_side_handle: int,
local_block_descs_ids: list[int],
remote_xfer_side_handle: int,
remote_block_descs_ids: list[int],
notif_msg: Optional[bytes] = None) -> int:
return uuid.uuid4().int
def transfer(self, handle: int) -> str:
return "PROC"
############################################################
# The following methods are for changing the behavior during testing.
############################################################
def set_cycles_before_xfer_done(self, cycles: int):
"""Set the number of cycles before a transfer is considered done."""
self._cycles_before_xfer_done = cycles

View File

@ -1188,9 +1188,15 @@ class IntermediateTensors:
"""For all pipeline stages except the last, we need to return the hidden
states and residuals to be sent to the next stage. This data structure
contains the hidden states and residuals for a request.
Each stage also needs to handle its own finished_sending and
finished_recving in case of kv transfer.
"""
tensors: dict[str, torch.Tensor]
# [req_ids]
finished_sending: Optional[set[str]] = None
finished_recving: Optional[set[str]] = None
def __init__(self, tensors):
# manually define this function, so that

View File

@ -9,8 +9,7 @@ import threading
import time
import traceback
import weakref
from collections import defaultdict
from concurrent.futures import CancelledError, Future, ThreadPoolExecutor
from concurrent.futures import Future, ThreadPoolExecutor
from dataclasses import dataclass
from enum import Enum, auto
from functools import partial
@ -27,6 +26,7 @@ from vllm.distributed import (destroy_distributed_environment,
destroy_model_parallel)
from vllm.distributed.device_communicators.shm_broadcast import (Handle,
MessageQueue)
from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
from vllm.executor.multiproc_worker_utils import (
_add_prefix, set_multiprocessing_worker_envs)
from vllm.logger import init_logger
@ -118,13 +118,8 @@ class MultiprocExecutor(Executor):
self.output_rank = self._get_output_rank()
self.has_connector = self.vllm_config.kv_transfer_config is not None
# Complete transfer tracker. Used by to track finished requests
# [req_id -> n_finished_workers]
self._recv_remaining_count = defaultdict[str,
int](lambda: self.world_size)
self._send_remaining_count = defaultdict[str,
int](lambda: self.world_size)
self.kv_output_aggregator = KVOutputAggregator(
self.parallel_config.world_size)
def start_worker_monitor(self):
workers = self.workers
@ -186,8 +181,9 @@ class MultiprocExecutor(Executor):
# aggregate all workers' outputs into a single output
if non_block:
return self._async_aggregate_workers_output(outputs)
return self._aggregate_workers_output(outputs)
return self.kv_output_aggregator.async_aggregate(
outputs, self.output_rank)
return self.kv_output_aggregator.aggregate(outputs, self.output_rank)
def collective_rpc(self,
method: Union[str, Callable],
@ -246,74 +242,6 @@ class MultiprocExecutor(Executor):
except TimeoutError as e:
raise TimeoutError(f"RPC call to {method} timed out.") from e
def _aggregate_workers_output(
self, outputs: list[ModelRunnerOutput]) -> ModelRunnerOutput:
# aggregate finished_sending, finished_recving from all workers
def update_finished_set(req_ids: Optional[set[str]],
remaining_count_dict: dict[str, int],
finished_set: set[str]) -> None:
for req_id in req_ids or ():
new_count = remaining_count_dict[req_id] - 1
if new_count == 0:
finished_set.add(req_id)
del remaining_count_dict[req_id]
else:
remaining_count_dict[req_id] = new_count
finished_sending = set[str]()
finished_recving = set[str]()
for output in outputs:
update_finished_set(output.finished_sending,
self._send_remaining_count, finished_sending)
update_finished_set(output.finished_recving,
self._recv_remaining_count, finished_recving)
# select output of the worker specified by output_rank
output = outputs[self.output_rank]
# set the aggregated finished_sending / finished_recving
output.finished_sending = finished_sending if finished_sending else None
output.finished_recving = finished_recving if finished_recving else None
return output
def _async_aggregate_workers_output(
self, output_futures: list[Future[ModelRunnerOutput]]
) -> (Future[ModelRunnerOutput]):
"""Takes a list of futures and returns a single future which resolves
to the respective list of outputs."""
result_future: Future[ModelRunnerOutput] = Future()
outputs: list[Optional[ModelRunnerOutput]] = [None
] * len(output_futures)
def make_callback(idx):
def callback(fut):
if result_future.done():
return
try:
outputs[idx] = fut.result()
except CancelledError:
result_future.cancel()
except Exception as e:
result_future.set_exception(e)
# this check assumes io_thread_pool uses a single thread
if all(outputs):
result_future.set_result(
self._aggregate_workers_output(
cast(list[ModelRunnerOutput], outputs)))
return callback
for i, output_future in enumerate(output_futures):
output_future.add_done_callback(make_callback(i))
return result_future
@staticmethod
def _ensure_worker_termination(worker_procs: list[BaseProcess]):
"""Ensure that all worker processes are terminated. Assumes workers have

View File

@ -2,33 +2,55 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from concurrent.futures import Future
from typing import Union
from typing import Optional, Union
from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
from vllm.executor.ray_distributed_executor import ( # noqa
RayDistributedExecutor as RayDistributedExecutorV0)
from vllm.logger import init_logger
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
from vllm.v1.executor.abstract import Executor
from vllm.v1.outputs import ModelRunnerOutput
logger = init_logger(__name__)
class FutureWrapper(Future):
"""A wrapper around a Ray output reference to meet the interface
of .execute_model().
"""A wrapper around Ray output reference to meet the interface
of .execute_model(): the top-level core busy loop expects the .result() API
to block and return a single output.
If an aggregator is provided, the outputs from all workers are aggregated on
the result() call. If not, only the first worker's output is returned.
"""
def __init__(self, ref):
def __init__(self, refs, aggregator: Optional[KVOutputAggregator] = None):
super().__init__()
self.ref = ref
self.refs = refs
self.aggregator = aggregator
def result(self, timeout=None):
if timeout is not None:
raise NotImplementedError("timeout is not supported")
return self.ref.get()
if self.aggregator is None:
return self.refs[0].get()
outputs = [ref.get() for ref in self.refs]
return self.aggregator.aggregate(outputs, output_rank=0)
class RayDistributedExecutor(RayDistributedExecutorV0, Executor):
"""Ray distributed executor using Ray Compiled Graphs."""
def _init_executor(self) -> None:
super()._init_executor()
# KV connector setup
self.has_connector = self.vllm_config.kv_transfer_config is not None
self.kv_output_aggregator = KVOutputAggregator(
self.parallel_config.world_size)
@property
def max_concurrent_batches(self) -> int:
"""Ray distributed executor supports pipeline parallelism,
@ -56,13 +78,24 @@ class RayDistributedExecutor(RayDistributedExecutorV0, Executor):
refs = self.forward_dag.execute(scheduler_output) # type: ignore
# When PP is not used, we block here until the result is available.
if self.max_concurrent_batches == 1:
return refs[0].get()
if not self.has_connector:
# Get output only from a single worker (output_rank)
# When PP is not used, we block here until the result is available.
if self.max_concurrent_batches == 1:
return refs[0].get()
# When PP is used, we return a FutureWrapper immediately so that
# the scheduler can yield to the next batch.
return FutureWrapper(refs[0])
# When PP is used, we return a FutureWrapper immediately so that
# the scheduler can yield to the next batch.
return FutureWrapper(refs)
# Get output from all workers when connector is present
if self.max_concurrent_batches == 1:
# Block and get results from all workers
outputs = [ref.get() for ref in refs]
return self.kv_output_aggregator.aggregate(outputs)
# Return a future that will aggregate outputs from all workers
return FutureWrapper(refs, self.kv_output_aggregator)
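
Roughly, the scheduler-side consumer of that non-blocking path looks like the sketch below (names are illustrative, assuming a connector is configured and pipeline parallelism keeps more than one batch in flight): the FutureWrapper is returned immediately and aggregation is deferred to result().

# Illustrative only: how the busy loop consumes the value returned above.
future = executor.execute_model(scheduler_output)  # FutureWrapper(refs, aggregator)

# ... more batches can be scheduled here before blocking ...

# result() blocks on every per-worker Ray ref, then runs
# KVOutputAggregator.aggregate() over the collected outputs.
model_runner_output = future.result()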
def reinitialize_distributed(
self, reconfig_request: ReconfigureDistributedRequest) -> None:

View File

@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
import gc
import time
from contextlib import contextmanager
@ -1270,6 +1271,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
hidden_states: torch.Tensor,
num_scheduled_tokens: int,
num_scheduled_tokens_np: np.ndarray,
finished_sending: Optional[set[str]],
finished_recving: Optional[set[str]],
) -> ModelRunnerOutput:
assert self.input_batch.num_reqs ==\
len(self.input_batch.pooling_params), \
@ -1304,6 +1307,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
logprobs=None,
prompt_logprobs_dict={},
pooler_output=pooler_output,
finished_sending=finished_sending,
finished_recving=finished_recving,
)
@torch.inference_mode()
@ -1314,12 +1319,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
) -> Union[ModelRunnerOutput, IntermediateTensors]:
self._update_states(scheduler_output)
if not scheduler_output.total_num_scheduled_tokens:
if has_kv_transfer_group():
with set_forward_context(None, self.vllm_config):
self.maybe_setup_kv_connector(scheduler_output)
if not has_kv_transfer_group():
# Return empty ModelRunnerOutput if there's no work to do.
return EMPTY_MODEL_RUNNER_OUTPUT
# Return empty ModelRunnerOutput if there's no work to do.
return EMPTY_MODEL_RUNNER_OUTPUT
return self.kv_connector_no_forward(scheduler_output)
# Prepare the decoder inputs.
(attn_metadata, attention_cuda_graphs, logits_indices,
@ -1412,6 +1416,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
)
self.maybe_wait_for_kv_save()
finished_sending, finished_recving = (
self.get_finished_kv_transfers(scheduler_output))
if self.use_aux_hidden_state_outputs:
hidden_states, aux_hidden_states = model_output
@ -1429,6 +1435,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if not get_pp_group().is_last_rank:
# For mid-pipeline stages, return the hidden states.
if not broadcast_pp_output:
if finished_sending or finished_recving:
hidden_states.finished_sending = finished_sending
hidden_states.finished_recving = finished_recving
return hidden_states
assert isinstance(hidden_states, IntermediateTensors)
get_pp_group().send_tensor_dict(hidden_states.tensors,
@ -1437,7 +1446,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
else:
if self.input_batch.pooling_params:
return self._pool(hidden_states, num_scheduled_tokens,
num_scheduled_tokens_np)
num_scheduled_tokens_np, finished_sending,
finished_recving)
sample_hidden_states = hidden_states[logits_indices]
logits = self.model.compute_logits(sample_hidden_states, None)
@ -1587,6 +1597,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
logprobs=logprobs_lists,
prompt_logprobs_dict=prompt_logprobs_dict,
pooler_output=[],
finished_sending=finished_sending,
finished_recving=finished_recving,
num_nans_in_logits=num_nans_in_logits,
)
@ -1711,6 +1723,31 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if has_kv_transfer_group():
get_kv_transfer_group().wait_for_save()
@staticmethod
def get_finished_kv_transfers(
scheduler_output: "SchedulerOutput",
) -> tuple[Optional[set[str]], Optional[set[str]]]:
if has_kv_transfer_group():
return get_kv_transfer_group().get_finished(
scheduler_output.finished_req_ids)
return None, None
def kv_connector_no_forward(
self, scheduler_output: "SchedulerOutput") -> ModelRunnerOutput:
# KV send/recv even if no work to do.
with set_forward_context(None, self.vllm_config):
self.maybe_setup_kv_connector(scheduler_output)
finished_sending, finished_recving = (
self.get_finished_kv_transfers(scheduler_output))
if not finished_sending and not finished_recving:
return EMPTY_MODEL_RUNNER_OUTPUT
output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
output.finished_sending = finished_sending
output.finished_recving = finished_recving
return output
def propose_ngram_draft_token_ids(
self,
sampled_token_ids: list[list[int]],

View File

@ -15,9 +15,7 @@ from vllm.config import VllmConfig
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment,
set_custom_all_reduce)
from vllm.distributed.kv_transfer import (ensure_kv_transfer_initialized,
get_kv_transfer_group,
has_kv_transfer_group)
from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
@ -335,25 +333,17 @@ class Worker(WorkerBase):
assert isinstance(output, IntermediateTensors)
get_pp_group().send_tensor_dict(output.tensors,
all_gather_group=get_tp_group())
output = EMPTY_MODEL_RUNNER_OUTPUT
# In case of PP with kv transfer, we need to pass through the
# finished_sending and finished_recving buffers.
empty_output = EMPTY_MODEL_RUNNER_OUTPUT
if output.finished_sending or output.finished_recving:
empty_output = copy.copy(empty_output)
empty_output.finished_sending = output.finished_sending
empty_output.finished_recving = output.finished_recving
output = empty_output
assert isinstance(output, ModelRunnerOutput)
if has_kv_transfer_group():
finished_sending, finished_recving = (
get_kv_transfer_group().get_finished(
scheduler_output.finished_req_ids))
if finished_sending or finished_recving:
if output is EMPTY_MODEL_RUNNER_OUTPUT:
output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
output.finished_sending = finished_sending
output.finished_recving = finished_recving
# Clear KVConnector state for this step.
get_kv_transfer_group().clear_connector_metadata()
# with a connector, the scheduler expects output from all workers
return output
# return output only from the driver worker
return output if self.is_driver_worker else None