[KVConnector] Aggregate finished requests on the scheduler (#19555)

Signed-off-by: Or Ozeri <oro@il.ibm.com>

parent: fdfd409f8f
commit: cc876d0f29
@@ -190,7 +190,9 @@ class KVConnectorBase_V1(ABC):
     ) -> tuple[Optional[set[str]], Optional[set[str]]]:
         """
         Notifies worker-side connector ids of requests that have
-        finished generating tokens.
+        finished generating tokens on the worker.
+        The scheduler process (via the MultiprocExecutor) will use this output
+        to track which workers are done.
 
         Returns:
             ids of requests that have finished asynchronous transfer
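For context, here is a minimal sketch of the per-worker contract the revised
docstring describes: each worker's connector returns only the request ids that
finished on that worker, and the scheduler-side executor is responsible for
combining the reports. ToyConnector is a hypothetical illustration, not part of
the vLLM API.

from typing import Optional

class ToyConnector:

    def __init__(self) -> None:
        # ids observed as finished on this worker only
        self._done_sending: set[str] = set()
        self._done_recving: set[str] = set()

    def get_finished(
        self, finished_req_ids: set[str]
    ) -> tuple[Optional[set[str]], Optional[set[str]]]:
        # Drain and return this worker's finished ids; the scheduler
        # process (via the MultiprocExecutor) counts reports per request.
        done_sending, self._done_sending = self._done_sending, set()
        done_recving, self._done_recving = self._done_recving, set()
        return done_sending or None, done_recving or None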
@@ -408,14 +408,6 @@ class NixlConnectorWorker:
         # Track the expiration time of requests that are waiting to be sent.
         self._reqs_to_send: dict[ReqId, float] = {}
 
-        # Complete transfer tracker. Used by the rank 0 to track finished
-        # transactions on ranks 1 to N-1.
-        # [req_id -> count]
-        self._done_recving_count: defaultdict[ReqId,
-                                              int] = defaultdict(lambda: 0)
-        self._done_sending_count: defaultdict[ReqId,
-                                              int] = defaultdict(lambda: 0)
-
         # Background thread for handling new handshake requests.
         self._nixl_handshake_listener_t: Optional[threading.Thread] = None
         # Background thread for initializing new NIXL handshakes.
@@ -830,15 +822,9 @@ class NixlConnectorWorker:
 
     def get_finished(self) -> tuple[set[str], set[str]]:
         """
-        Get requests that are done sending or recving.
-
-        In TP>1 setup, each rank exchanges KVs with its counterpart
-        ranks independently. get_finished() runs in a worker creates
-        the done_sending and done_recving sets that are sent to the
-        scheduler via ModelRunnerOutput by Rank 0. To ensure trnxs
-        are done before adding to finished, Ranks 1 to N-1 communicate
-        to Rank 0 once their transaction is done + Rank 0 returns
-        finished sets to Scheduler only once all ranks are done.
+        Get requests that are done sending or recving on this specific worker.
+        The scheduler process (via the MultiprocExecutor) will use this output
+        to track which workers are done.
         """
         done_sending = self._get_new_notifs()
         done_recving = self._pop_done_transfers(self._recving_transfers)
@@ -858,50 +844,7 @@
             del self._reqs_to_send[req_id]
             done_sending.add(req_id)
 
-        if self.world_size == 1:
-            return done_sending, done_recving
-
-        # Rank 0: get finished from all other ranks.
-        if self.tp_rank == 0:
-            for req_id in done_sending:
-                self._done_sending_count[req_id] += 1
-            for req_id in done_recving:
-                self._done_recving_count[req_id] += 1
-
-            # Keep track of how many other ranks have finished.
-            other_ranks_finished_ids: list[str] = []
-            for i in range(1, self.world_size):
-                other_ranks_finished_ids.extend(
-                    self.tp_group.recv_object(src=i))
-            for req_id in other_ranks_finished_ids:
-                if (req_id in self._done_recving_count
-                        or req_id in self._recving_transfers):
-                    self._done_recving_count[req_id] += 1
-                else:
-                    self._done_sending_count[req_id] += 1
-
-            # Return ids that finished on all ranks to the scheduler.
-            all_done_recving: set[str] = set()
-            for req_id in list(self._done_recving_count.keys()):
-                if self._done_recving_count[req_id] == self.world_size:
-                    del self._done_recving_count[req_id]
-                    all_done_recving.add(req_id)
-
-            all_done_sending: set[str] = set()
-            for req_id in list(self._done_sending_count.keys()):
-                if self._done_sending_count[req_id] >= self.world_size:
-                    del self._done_sending_count[req_id]
-                    all_done_sending.add(req_id)
-
-            return all_done_sending, all_done_recving
-
-        # Ranks 1 to N-1: send finished ids to Rank 0.
-        else:
-            finished_req_ids = list(done_recving.union(done_sending))
-            self.tp_group.send_object(finished_req_ids, dst=0)
-
-            # Unused as only Rank 0 results are sent to scheduler.
-            return done_sending, done_recving
+        return done_sending, done_recving
 
     def _get_new_notifs(self) -> set[str]:
         """
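The block deleted above was the old worker-side design: rank 0 collected
finished ids from ranks 1 to N-1 over the TP group on every step. The invariant
it enforced still holds, it is just enforced by the executor now: each TP rank
transfers its own KV shard independently, so a request may only be released
once every rank has reported it. A toy illustration in plain Python (no vLLM
types; a world_size of 4 is arbitrary):

world_size = 4

# ranks 0..2 have finished sending req-A; rank 3 is still transferring
reports = [{"req-A"}, {"req-A"}, {"req-A"}, set()]
assert set.intersection(*reports) == set()  # req-A not safe to release yet

reports[3].add("req-A")  # rank 3 catches up, possibly steps later
assert set.intersection(*reports) == {"req-A"}  # now finished everywhere

Because the reports arrive across different scheduler steps, the executor keeps
a per-request countdown rather than a per-step intersection; see the
_send_remaining_count / _recv_remaining_count trackers below.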
@@ -9,7 +9,8 @@ import threading
 import time
 import traceback
 import weakref
-from concurrent.futures import Future, ThreadPoolExecutor
+from collections import defaultdict
+from concurrent.futures import CancelledError, Future, ThreadPoolExecutor
 from dataclasses import dataclass
 from enum import Enum, auto
 from functools import partial
@@ -111,10 +112,19 @@ class MultiprocExecutor(Executor):
         if self.max_concurrent_batches > 1:
             # Note: must use only 1 IO thread to keep dequeue sequence
             # from the response queue
+            # _async_aggregate_workers_output also assumes a single IO thread
             self.io_thread_pool = ThreadPoolExecutor(
                 max_workers=1, thread_name_prefix="mp_exec_io")
 
         self.output_rank = self._get_output_rank()
+        self.has_connector = self.vllm_config.kv_transfer_config is not None
+
+        # Complete transfer tracker. Used to track finished requests
+        # [req_id -> n_remaining_workers]
+        self._recv_remaining_count = defaultdict[str,
+                                                 int](lambda: self.world_size)
+        self._send_remaining_count = defaultdict[str,
+                                                 int](lambda: self.world_size)
 
     def start_worker_monitor(self):
         workers = self.workers
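A note on the tracker's construction: defaultdict[str, int] is a parametrized
generic alias (PEP 585), and calling it simply forwards to the plain
defaultdict constructor, so the type parameters come for free at runtime. A
sketch of the countdown behaviour, with a stand-in world_size:

from collections import defaultdict

world_size = 2  # stand-in value
remaining = defaultdict[str, int](lambda: world_size)

remaining["req-a"] -= 1         # first worker reports req-a
assert remaining["req-a"] == 1  # entry created at world_size, then decremented
remaining["req-a"] -= 1         # second (last) worker reports
assert remaining["req-a"] == 0  # all workers done; the executor deletes it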
@@ -155,13 +165,29 @@ class MultiprocExecutor(Executor):
         self,
         scheduler_output,
     ) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]:
-        (output, ) = self.collective_rpc(
+        non_block = self.max_concurrent_batches > 1
+
+        if not self.has_connector:
+            # get output only from a single worker (output_rank)
+            (output, ) = self.collective_rpc(
+                "execute_model",
+                args=(scheduler_output, ),
+                unique_reply_rank=self.output_rank,
+                non_block=non_block,
+                timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS)
+            return output
+
+        # get output from all workers
+        outputs = self.collective_rpc(
             "execute_model",
             args=(scheduler_output, ),
-            unique_reply_rank=self.output_rank,
-            non_block=self.max_concurrent_batches > 1,
+            non_block=non_block,
             timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS)
-        return output
+
+        # aggregate all workers output to a single output
+        if non_block:
+            return self._async_aggregate_workers_output(outputs)
+        return self._aggregate_workers_output(outputs)
 
     def collective_rpc(self,
                        method: Union[str, Callable],
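The dispatch above now has two reply shapes: without a connector,
collective_rpc returns a single reply from output_rank; with a connector, it
returns one reply per worker, which must then be aggregated. A stubbed sketch
of just those shapes (collective_rpc_stub is hypothetical, not the real RPC):

world_size = 2  # stand-in value

def collective_rpc_stub(unique_reply_rank=None):
    # stand-in for collective_rpc's reply handling, no real workers
    replies = [f"output-from-rank-{r}" for r in range(world_size)]
    if unique_reply_rank is not None:
        return [replies[unique_reply_rank]]
    return replies

(output, ) = collective_rpc_stub(unique_reply_rank=0)  # no connector
assert output == "output-from-rank-0"

outputs = collective_rpc_stub()  # with a connector: one reply per worker
assert len(outputs) == world_size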
@@ -220,6 +246,80 @@ class MultiprocExecutor(Executor):
         except TimeoutError as e:
             raise TimeoutError(f"RPC call to {method} timed out.") from e
 
+    def _aggregate_workers_output(
+            self, outputs: list[ModelRunnerOutput]) -> ModelRunnerOutput:
+        # aggregate finished_sending, finished_recving from all workers
+
+        finished_sending = set[str]()
+        finished_recving = set[str]()
+        for output in outputs:
+            # update finished_sending
+            for req_id in output.finished_sending or []:
+                new_count = self._send_remaining_count[req_id] - 1
+                if new_count == 0:
+                    # got response from all workers, report back to scheduler
+                    finished_sending.add(req_id)
+                    del self._send_remaining_count[req_id]
+                else:
+                    self._send_remaining_count[req_id] = new_count
+
+            # update finished_recving
+            for req_id in output.finished_recving or []:
+                new_count = self._recv_remaining_count[req_id] - 1
+                if new_count == 0:
+                    # got response from all workers, report back to scheduler
+                    finished_recving.add(req_id)
+                    del self._recv_remaining_count[req_id]
+                else:
+                    self._recv_remaining_count[req_id] = new_count
+
+        # select output of the worker specified by output_rank
+        output = outputs[self.output_rank]
+
+        # set the aggregated finished_sending / finished_recving
+        if finished_sending:
+            output.finished_sending = finished_sending
+        if finished_recving:
+            output.finished_recving = finished_recving
+
+        return output
+
+    def _async_aggregate_workers_output(
+        self, output_futures: list[Future[ModelRunnerOutput]]
+    ) -> Future[ModelRunnerOutput]:
+        """Takes a list of futures and returns a single future which resolves
+        to the aggregated output once all worker futures have resolved."""
+        result_future: Future[ModelRunnerOutput] = Future()
+
+        outputs: list[Optional[ModelRunnerOutput]] = [None
+                                                      ] * len(output_futures)
+
+        def make_callback(idx):
+
+            def callback(fut):
+                if result_future.done():
+                    return
+
+                try:
+                    outputs[idx] = fut.result()
+                except CancelledError:
+                    result_future.cancel()
+                except Exception as e:
+                    result_future.set_exception(e)
+
+                # this check assumes io_thread_pool uses a single thread
+                if all(outputs):
+                    result_future.set_result(
+                        self._aggregate_workers_output(
+                            cast(list[ModelRunnerOutput], outputs)))
+
+            return callback
+
+        for i, output_future in enumerate(output_futures):
+            output_future.add_done_callback(make_callback(i))
+
+        return result_future
+
     @staticmethod
     def _ensure_worker_termination(worker_procs: list[BaseProcess]):
         """Ensure that all worker processes are terminated. Assumes workers have
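The callback pattern in _async_aggregate_workers_output can be exercised on its
own. Below is a runnable sketch of the same technique, folding N futures into
one future; merge_futures is a generic stand-in with no vLLM types, and the
single-threaded-callback assumption mirrors the io_thread_pool note above.

import time
from concurrent.futures import Future, ThreadPoolExecutor


def merge_futures(futures: list[Future]) -> Future:
    result: Future = Future()
    slots: list = [None] * len(futures)

    def make_callback(idx: int):

        def callback(fut: Future) -> None:
            if result.done():
                return
            try:
                slots[idx] = fut.result()
            except Exception as e:
                result.set_exception(e)
                return
            # safe without a lock only while callbacks run on one thread,
            # mirroring the single-threaded io_thread_pool above
            if all(s is not None for s in slots):
                result.set_result(slots)

        return callback

    for i, fut in enumerate(futures):
        fut.add_done_callback(make_callback(i))
    return result


with ThreadPoolExecutor(max_workers=1) as pool:
    futs = [pool.submit(lambda r=r: (time.sleep(0.05), r)[1]) for r in range(3)]
    assert merge_futures(futs).result(timeout=5) == [0, 1, 2]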
@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-import copy
 import gc
 import time
 import weakref
@@ -1234,8 +1233,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         hidden_states: torch.Tensor,
         num_scheduled_tokens: int,
         num_scheduled_tokens_np: np.ndarray,
-        finished_sending: Optional[set[str]],
-        finished_recving: Optional[set[str]],
     ) -> ModelRunnerOutput:
         assert self.input_batch.num_reqs ==\
             len(self.input_batch.pooling_params), \
@@ -1270,8 +1267,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             logprobs=None,
             prompt_logprobs_dict={},
             pooler_output=pooler_output,
-            finished_sending=finished_sending,
-            finished_recving=finished_recving,
         )
 
     @torch.inference_mode()
@@ -1282,11 +1277,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
     ) -> Union[ModelRunnerOutput, IntermediateTensors]:
         self._update_states(scheduler_output)
         if not scheduler_output.total_num_scheduled_tokens:
-            if not has_kv_transfer_group():
-                # Return empty ModelRunnerOutput if there's no work to do.
-                return EMPTY_MODEL_RUNNER_OUTPUT
+            if has_kv_transfer_group():
+                with set_forward_context(None, self.vllm_config):
+                    self.maybe_setup_kv_connector(scheduler_output)
 
-            return self.kv_connector_no_forward(scheduler_output)
+            # Return empty ModelRunnerOutput if there's no work to do.
+            return EMPTY_MODEL_RUNNER_OUTPUT
 
         # Prepare the decoder inputs.
         (attn_metadata, attention_cuda_graphs, logits_indices,
@@ -1379,8 +1375,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         )
 
         self.maybe_wait_for_kv_save()
-        finished_sending, finished_recving = (
-            self.get_finished_kv_transfers(scheduler_output))
 
         if self.use_aux_hidden_state_outputs:
             hidden_states, aux_hidden_states = model_output
@@ -1406,8 +1400,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         else:
             if self.input_batch.pooling_params:
                 return self._pool(hidden_states, num_scheduled_tokens,
-                                  num_scheduled_tokens_np, finished_sending,
-                                  finished_recving)
+                                  num_scheduled_tokens_np)
 
         sample_hidden_states = hidden_states[logits_indices]
         logits = self.model.compute_logits(sample_hidden_states, None)
@@ -1560,8 +1553,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             logprobs=logprobs_lists,
             prompt_logprobs_dict=prompt_logprobs_dict,
             pooler_output=[],
-            finished_sending=finished_sending,
-            finished_recving=finished_recving,
             num_nans_in_logits=num_nans_in_logits,
         )
 
@@ -1686,22 +1677,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         spec_token_ids = draft_token_ids.tolist()
         return spec_token_ids
 
-    def kv_connector_no_forward(
-            self, scheduler_output: "SchedulerOutput") -> ModelRunnerOutput:
-        # KV send/recv even if no work to do.
-        with set_forward_context(None, self.vllm_config):
-            self.maybe_setup_kv_connector(scheduler_output)
-            finished_sending, finished_recving = (
-                self.get_finished_kv_transfers(scheduler_output))
-
-        if not finished_sending and not finished_recving:
-            return EMPTY_MODEL_RUNNER_OUTPUT
-
-        output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
-        output.finished_sending = finished_sending
-        output.finished_recving = finished_recving
-        return output
-
     @staticmethod
     def maybe_setup_kv_connector(scheduler_output: "SchedulerOutput"):
         # Update KVConnector with the KVConnector metadata forward().
@@ -1723,15 +1698,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         if has_kv_transfer_group():
             get_kv_transfer_group().wait_for_save()
 
-    @staticmethod
-    def get_finished_kv_transfers(
-        scheduler_output: "SchedulerOutput",
-    ) -> tuple[Optional[set[str]], Optional[set[str]]]:
-        if has_kv_transfer_group():
-            return get_kv_transfer_group().get_finished(
-                scheduler_output.finished_req_ids)
-        return None, None
-
     def propose_ngram_draft_token_ids(
         self,
         sampled_token_ids: list[list[int]],
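get_finished_kv_transfers() disappears from the model runner because the worker
itself (next file) now queries the connector, and the executor does the
counting. A toy end-to-end of that division of labor, with stand-in names and
a stand-in world_size:

from collections import defaultdict

world_size = 2
remaining = defaultdict[str, int](lambda: world_size)  # executor side

def worker_step(worker_finished: set[str]) -> set[str]:
    # worker side: just report what finished locally this step
    return worker_finished

def executor_step(reports: list[set[str]]) -> set[str]:
    # executor side: a request surfaces once all workers have reported it
    done = set()
    for report in reports:
        for req_id in report:
            remaining[req_id] -= 1
            if remaining[req_id] == 0:
                del remaining[req_id]
                done.add(req_id)
    return done

# step 1: only worker 0 reports req-X -> not surfaced to the scheduler yet
assert executor_step([worker_step({"req-X"}), worker_step(set())]) == set()
# step 2: worker 1 catches up -> req-X surfaces exactly once
assert executor_step([worker_step(set()), worker_step({"req-X"})]) == {"req-X"}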
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """A GPU worker class."""
+import copy
 import gc
 import os
 from typing import TYPE_CHECKING, Optional
@@ -14,7 +15,9 @@ from vllm.config import VllmConfig
 from vllm.distributed import (ensure_model_parallel_initialized,
                               init_distributed_environment,
                               set_custom_all_reduce)
-from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
+from vllm.distributed.kv_transfer import (ensure_kv_transfer_initialized,
+                                          get_kv_transfer_group,
+                                          has_kv_transfer_group)
 from vllm.distributed.parallel_state import get_pp_group, get_tp_group
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
@@ -23,7 +26,7 @@ from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
 from vllm.utils import GiB_bytes, MemorySnapshot, memory_profiling
 from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
-from vllm.v1.outputs import ModelRunnerOutput
+from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput
 from vllm.v1.utils import report_usage_stats
 from vllm.v1.worker.gpu_model_runner import GPUModelRunner
 from vllm.v1.worker.worker_base import WorkerBase
@@ -316,14 +319,29 @@ class Worker(WorkerBase):
 
         output = self.model_runner.execute_model(scheduler_output,
                                                  intermediate_tensors)
 
         parallel_config = self.vllm_config.parallel_config
         if parallel_config.distributed_executor_backend != "external_launcher" \
             and not get_pp_group().is_last_rank:
             assert isinstance(output, IntermediateTensors)
             get_pp_group().send_tensor_dict(output.tensors,
                                             all_gather_group=get_tp_group())
-            return None
+            output = EMPTY_MODEL_RUNNER_OUTPUT
 
         assert isinstance(output, ModelRunnerOutput)
+        if has_kv_transfer_group():
+            finished_sending, finished_recving = (
+                get_kv_transfer_group().get_finished(
+                    scheduler_output.finished_req_ids))
+            if finished_sending or finished_recving:
+                if output is EMPTY_MODEL_RUNNER_OUTPUT:
+                    output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
+                output.finished_sending = finished_sending
+                output.finished_recving = finished_recving
+
+            # with a connector, the scheduler expects output from all workers
+            return output
+
+        # return output only from the driver worker
         return output if self.is_driver_worker else None
 
     def profile(self, is_start: bool = True):
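One detail in the worker change above is worth spelling out:
EMPTY_MODEL_RUNNER_OUTPUT is a shared module-level constant, so finished ids
must be attached to a copy, never to the singleton itself. A toy illustration
with a hypothetical stand-in type:

import copy
from dataclasses import dataclass
from typing import Optional


@dataclass
class ToyOutput:  # hypothetical stand-in for ModelRunnerOutput
    finished_sending: Optional[set[str]] = None
    finished_recving: Optional[set[str]] = None


EMPTY = ToyOutput()  # stand-in for EMPTY_MODEL_RUNNER_OUTPUT


def attach_finished(output: ToyOutput, sending, recving) -> ToyOutput:
    if output is EMPTY:
        # never mutate the shared singleton; later steps reuse it
        output = copy.copy(EMPTY)
    output.finished_sending = sending
    output.finished_recving = recving
    return output


out = attach_finished(EMPTY, {"req-1"}, None)
assert out.finished_sending == {"req-1"}
assert EMPTY.finished_sending is None  # singleton untouched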