mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-11 14:56:16 +08:00
[V1] Exception Handling when Loading KV Cache from Remote Store (#21534)
Signed-off-by: liuyumoye <adeline_ly2023@outlook.com> Co-authored-by: liuyumoye <adeline_ly2023@outlook.com>
This commit is contained in:
parent
04ff4be310
commit
15a72ac478
@ -0,0 +1,120 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
import logging
|
||||||
|
import random
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.config import VllmConfig
|
||||||
|
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||||
|
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
|
||||||
|
from vllm.v1.core.sched.output import SchedulerOutput
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from vllm.attention.backends.abstract import AttentionMetadata
|
||||||
|
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||||
|
from vllm.v1.request import Request
|
||||||
|
|
||||||
|
logger = logging.getLogger()
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class RandomDropConnectorMetadata(KVConnectorMetadata):
|
||||||
|
req_meta: dict[str, list[int]]
|
||||||
|
|
||||||
|
|
||||||
|
class RandomDropConnector(KVConnectorBase_V1):
|
||||||
|
"""
|
||||||
|
A connector designed for fault tolerance testing by randomly dropping
|
||||||
|
kv data during the process of loading or receiving KV cache.
|
||||||
|
|
||||||
|
This class simulates real-world scenarios where requests or data
|
||||||
|
might be lost or timeout, allowing developers to test and validate the
|
||||||
|
system's ability to handle such failures.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
finished_recving_kv_req_ids (set[str]): A set of request IDs that
|
||||||
|
have completed receiving KV cache data.
|
||||||
|
finished_loading_dict (dict[str, int]): A dictionary that tracks
|
||||||
|
the actual number of tokens loaded from the remote KV store
|
||||||
|
for each completed request. The keys are request IDs, and
|
||||||
|
the values are the corresponding token counts.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
|
||||||
|
super().__init__(vllm_config=vllm_config, role=role)
|
||||||
|
|
||||||
|
self.failure_request: list[str] = []
|
||||||
|
self._reqs_need_recv: dict[str, list[int]] = {}
|
||||||
|
self._finish_load: dict[str, int] = {}
|
||||||
|
|
||||||
|
self.chunk_size = 256
|
||||||
|
|
||||||
|
############################################################
|
||||||
|
# Scheduler Side Methods
|
||||||
|
############################################################
|
||||||
|
|
||||||
|
def get_num_new_matched_tokens(
|
||||||
|
self, request: "Request",
|
||||||
|
num_computed_tokens: int) -> tuple[int, bool]:
|
||||||
|
if request.request_id in self.failure_request:
|
||||||
|
self.failure_request.remove(request.request_id)
|
||||||
|
return 0, False
|
||||||
|
num_external_hit_tokens = request.num_prompt_tokens - 1
|
||||||
|
logger.info(
|
||||||
|
"request %s num_prompt_tokens %d num_external_hit_tokens %d",
|
||||||
|
request.request_id, request.num_prompt_tokens,
|
||||||
|
num_external_hit_tokens)
|
||||||
|
return num_external_hit_tokens, True
|
||||||
|
|
||||||
|
def update_state_after_alloc(self, request: "Request",
|
||||||
|
blocks: "KVCacheBlocks",
|
||||||
|
num_external_tokens: int):
|
||||||
|
if num_external_tokens > 0:
|
||||||
|
self._reqs_need_recv[
|
||||||
|
request.
|
||||||
|
request_id] = request.prompt_token_ids[:num_external_tokens]
|
||||||
|
|
||||||
|
def build_connector_meta(
|
||||||
|
self,
|
||||||
|
scheduler_output: SchedulerOutput,
|
||||||
|
) -> KVConnectorMetadata:
|
||||||
|
req_meta = self._reqs_need_recv.copy()
|
||||||
|
self._reqs_need_recv.clear()
|
||||||
|
return RandomDropConnectorMetadata(req_meta)
|
||||||
|
|
||||||
|
def add_failure_request(self, request: "Request"):
|
||||||
|
self.failure_request.append(request.request_id)
|
||||||
|
|
||||||
|
def start_load_kv(self, forward_context, **kwargs) -> None:
|
||||||
|
for request_id, hit_tokens in self._get_connector_metadata(
|
||||||
|
).req_meta.items():
|
||||||
|
num_actual_load_tokens = self.load_kv(request_id, hit_tokens)
|
||||||
|
logger.info("request %s hit_tokens %d num_actual_load_tokens %d",
|
||||||
|
request_id, len(hit_tokens), num_actual_load_tokens)
|
||||||
|
self._finish_load[request_id] = num_actual_load_tokens
|
||||||
|
|
||||||
|
def wait_for_layer_load(self, layer_name: str) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
|
||||||
|
attn_metadata: "AttentionMetadata", **kwargs) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def wait_for_save(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def load_kv(self, request_id, hit_tokens):
|
||||||
|
num_actual_load_tokens = random.randint(0, len(hit_tokens))
|
||||||
|
return num_actual_load_tokens
|
||||||
|
|
||||||
|
def get_finished_loading(self) -> dict[str, int]:
|
||||||
|
if not self._finish_load:
|
||||||
|
return {}
|
||||||
|
finished_loading = self._finish_load.copy()
|
||||||
|
self._finish_load.clear()
|
||||||
|
|
||||||
|
return finished_loading
|
||||||
16
tests/v1/kv_connector/kv_load_exception_handling/test.sh
Normal file
16
tests/v1/kv_connector/kv_load_exception_handling/test.sh
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
|
||||||
|
export PYTHONPATH=$PYTHONPATH:$SCRIPT_DIR
|
||||||
|
|
||||||
|
vllm serve DeepSeek-V2-Lite-Chat \
|
||||||
|
--trust-remote-code \
|
||||||
|
--served-model-name vllm_cpu_offload \
|
||||||
|
--max-model-len 32768 \
|
||||||
|
--no-enable-prefix-caching \
|
||||||
|
--max-seq-len-to-capture 10000 \
|
||||||
|
--max-num-seqs 64 \
|
||||||
|
--gpu-memory-utilization 0.9 \
|
||||||
|
--host 0.0.0.0 \
|
||||||
|
-tp 2 \
|
||||||
|
--kv-transfer-config '{"kv_connector":"RandomDropConnector","kv_role":"kv_both","kv_connector_module_path":"random_drop_connector"}'
|
||||||
@ -139,13 +139,27 @@ class KVOutputAggregator:
|
|||||||
finished_set.add(req_id)
|
finished_set.add(req_id)
|
||||||
del remaining_count_dict[req_id]
|
del remaining_count_dict[req_id]
|
||||||
|
|
||||||
|
def update_finished_load_dict(worker_finished_loading_dict: dict[str,
|
||||||
|
int],
|
||||||
|
finished_loading_dict: dict[str, int]):
|
||||||
|
for req_id, num_actual_load_tokens in (worker_finished_loading_dict
|
||||||
|
or {}).items():
|
||||||
|
if req_id in finished_loading_dict:
|
||||||
|
finished_loading_dict[req_id] = min(
|
||||||
|
finished_loading_dict[req_id], num_actual_load_tokens)
|
||||||
|
else:
|
||||||
|
finished_loading_dict[req_id] = num_actual_load_tokens
|
||||||
|
|
||||||
finished_sending = set[str]()
|
finished_sending = set[str]()
|
||||||
finished_recving = set[str]()
|
finished_recving = set[str]()
|
||||||
|
finished_loading_dict: dict[str, int] = {}
|
||||||
for output in outputs:
|
for output in outputs:
|
||||||
update_finished_set(output.finished_sending,
|
update_finished_set(output.finished_sending,
|
||||||
self._send_remaining_count, finished_sending)
|
self._send_remaining_count, finished_sending)
|
||||||
update_finished_set(output.finished_recving,
|
update_finished_set(output.finished_recving,
|
||||||
self._recv_remaining_count, finished_recving)
|
self._recv_remaining_count, finished_recving)
|
||||||
|
update_finished_load_dict(output.finished_loading_dict,
|
||||||
|
finished_loading_dict)
|
||||||
|
|
||||||
# select output of the worker specified by output_rank
|
# select output of the worker specified by output_rank
|
||||||
output = outputs[output_rank]
|
output = outputs[output_rank]
|
||||||
@ -157,7 +171,7 @@ class KVOutputAggregator:
|
|||||||
# send/recv
|
# send/recv
|
||||||
output.finished_sending = finished_sending if finished_sending else None
|
output.finished_sending = finished_sending if finished_sending else None
|
||||||
output.finished_recving = finished_recving if finished_recving else None
|
output.finished_recving = finished_recving if finished_recving else None
|
||||||
|
output.finished_loading_dict = finished_loading_dict or None
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def async_aggregate(self,
|
def async_aggregate(self,
|
||||||
|
|||||||
@ -28,6 +28,9 @@ The class provides the following primitives:
|
|||||||
|
|
||||||
get_finished() - called with ids of finished requests, returns
|
get_finished() - called with ids of finished requests, returns
|
||||||
ids of requests that have completed async sending/recving.
|
ids of requests that have completed async sending/recving.
|
||||||
|
get_finished_loading() - called with scheduler outputs, returns
|
||||||
|
a dictionary that the keys are request IDs and the values are
|
||||||
|
the actual number of tokens loaded from the remote KV cache
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import enum
|
import enum
|
||||||
@ -219,6 +222,23 @@ class KVConnectorBase_V1(ABC):
|
|||||||
"""
|
"""
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
|
def get_finished_loading(
|
||||||
|
self, scheduler_output: "SchedulerOutput") -> dict[str, int]:
|
||||||
|
"""
|
||||||
|
Retrieves the actual number of tokens loaded for requests that have
|
||||||
|
completed the asynchronous loading process from the remote KV cache.
|
||||||
|
|
||||||
|
This function is used by the scheduler process (via the Executors)
|
||||||
|
to track the progress of requests and determine which requests have
|
||||||
|
successfully finished loading their KV cache data.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A dictionary where the keys are request IDs and the values are the
|
||||||
|
corresponding number of tokens that have been successfully loaded
|
||||||
|
for each request.
|
||||||
|
"""
|
||||||
|
return {}
|
||||||
|
|
||||||
# ==============================
|
# ==============================
|
||||||
# Scheduler-side methods
|
# Scheduler-side methods
|
||||||
# ==============================
|
# ==============================
|
||||||
|
|||||||
@ -1167,6 +1167,8 @@ class IntermediateTensors:
|
|||||||
# [req_ids]
|
# [req_ids]
|
||||||
finished_sending: Optional[set[str]] = None
|
finished_sending: Optional[set[str]] = None
|
||||||
finished_recving: Optional[set[str]] = None
|
finished_recving: Optional[set[str]] = None
|
||||||
|
#req_id -> num_actual_load_tokens
|
||||||
|
finished_loading_dict: Optional[dict[str, int]] = None
|
||||||
|
|
||||||
def __init__(self, tensors):
|
def __init__(self, tensors):
|
||||||
# manually define this function, so that
|
# manually define this function, so that
|
||||||
|
|||||||
@ -118,6 +118,9 @@ class Scheduler(SchedulerInterface):
|
|||||||
|
|
||||||
# KV Connector: requests in process of async KV loading or recving
|
# KV Connector: requests in process of async KV loading or recving
|
||||||
self.finished_recving_kv_req_ids: set[str] = set()
|
self.finished_recving_kv_req_ids: set[str] = set()
|
||||||
|
# The keys are request IDs, and the values are corresponding token
|
||||||
|
# count that have been successfully loaded from the remote KV store
|
||||||
|
self.finished_loading_dict: dict[str, int] = {}
|
||||||
|
|
||||||
# Encoder-related.
|
# Encoder-related.
|
||||||
# Calculate encoder cache size if applicable
|
# Calculate encoder cache size if applicable
|
||||||
@ -1094,6 +1097,27 @@ class Scheduler(SchedulerInterface):
|
|||||||
(block_ids, ) = self.kv_cache_manager.get_block_ids(request.request_id)
|
(block_ids, ) = self.kv_cache_manager.get_block_ids(request.request_id)
|
||||||
return self.connector.request_finished(request, block_ids)
|
return self.connector.request_finished(request, block_ids)
|
||||||
|
|
||||||
|
def _update_actual_load_token_num_from_remote_kv(self,
|
||||||
|
request: Request) -> bool:
|
||||||
|
|
||||||
|
num_actual_load_tokens = self.finished_loading_dict.pop(
|
||||||
|
request.request_id)
|
||||||
|
num_computed_tokens = num_actual_load_tokens
|
||||||
|
assert self.connector is not None
|
||||||
|
if num_actual_load_tokens <= 0 and hasattr(self.connector,
|
||||||
|
"add_failure_request"):
|
||||||
|
self.connector.add_failure_request(request)
|
||||||
|
return True
|
||||||
|
|
||||||
|
if num_actual_load_tokens == request.num_tokens:
|
||||||
|
num_computed_tokens -= 1
|
||||||
|
|
||||||
|
self.kv_cache_manager.cache_blocks(request, num_computed_tokens)
|
||||||
|
|
||||||
|
# Update the request state for scheduling.
|
||||||
|
request.num_computed_tokens = num_computed_tokens
|
||||||
|
return True
|
||||||
|
|
||||||
def _update_waiting_for_remote_kv(self, request: Request) -> bool:
|
def _update_waiting_for_remote_kv(self, request: Request) -> bool:
|
||||||
"""
|
"""
|
||||||
KV Connector: check if the request_id is finished_recving.
|
KV Connector: check if the request_id is finished_recving.
|
||||||
@ -1107,6 +1131,9 @@ class Scheduler(SchedulerInterface):
|
|||||||
WAITING_FOR_REMOTE_KV.
|
WAITING_FOR_REMOTE_KV.
|
||||||
"""
|
"""
|
||||||
assert self.connector is not None
|
assert self.connector is not None
|
||||||
|
if request.request_id in self.finished_loading_dict:
|
||||||
|
return self._update_actual_load_token_num_from_remote_kv(request)
|
||||||
|
|
||||||
if request.request_id not in self.finished_recving_kv_req_ids:
|
if request.request_id not in self.finished_recving_kv_req_ids:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@ -1145,3 +1172,6 @@ class Scheduler(SchedulerInterface):
|
|||||||
for req_id in (model_runner_output.finished_sending or ()):
|
for req_id in (model_runner_output.finished_sending or ()):
|
||||||
logger.debug("Finished sending KV transfer for request %s", req_id)
|
logger.debug("Finished sending KV transfer for request %s", req_id)
|
||||||
self._free_blocks(self.requests[req_id])
|
self._free_blocks(self.requests[req_id])
|
||||||
|
if model_runner_output.finished_loading_dict:
|
||||||
|
self.finished_loading_dict.update(
|
||||||
|
model_runner_output.finished_loading_dict)
|
||||||
|
|||||||
@ -107,6 +107,8 @@ class ModelRunnerOutput:
|
|||||||
# [req_ids]
|
# [req_ids]
|
||||||
finished_sending: Optional[set[str]] = None
|
finished_sending: Optional[set[str]] = None
|
||||||
finished_recving: Optional[set[str]] = None
|
finished_recving: Optional[set[str]] = None
|
||||||
|
# req_id -> actual_load_token from connector
|
||||||
|
finished_loading_dict: Optional[dict[str, int]] = None
|
||||||
|
|
||||||
# req_id -> num_nans_in_logits
|
# req_id -> num_nans_in_logits
|
||||||
num_nans_in_logits: Optional[dict[str, int]] = None
|
num_nans_in_logits: Optional[dict[str, int]] = None
|
||||||
@ -121,4 +123,5 @@ EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(req_ids=[],
|
|||||||
pooler_output=[],
|
pooler_output=[],
|
||||||
finished_sending=None,
|
finished_sending=None,
|
||||||
finished_recving=None,
|
finished_recving=None,
|
||||||
|
finished_loading_dict=None,
|
||||||
num_nans_in_logits=None)
|
num_nans_in_logits=None)
|
||||||
|
|||||||
@ -1350,6 +1350,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
num_scheduled_tokens_np: np.ndarray,
|
num_scheduled_tokens_np: np.ndarray,
|
||||||
finished_sending: Optional[set[str]],
|
finished_sending: Optional[set[str]],
|
||||||
finished_recving: Optional[set[str]],
|
finished_recving: Optional[set[str]],
|
||||||
|
finished_loading_dict: Optional[dict[str, int]],
|
||||||
) -> ModelRunnerOutput:
|
) -> ModelRunnerOutput:
|
||||||
assert self.input_batch.num_reqs ==\
|
assert self.input_batch.num_reqs ==\
|
||||||
len(self.input_batch.pooling_params), \
|
len(self.input_batch.pooling_params), \
|
||||||
@ -1386,6 +1387,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
pooler_output=pooler_output,
|
pooler_output=pooler_output,
|
||||||
finished_sending=finished_sending,
|
finished_sending=finished_sending,
|
||||||
finished_recving=finished_recving,
|
finished_recving=finished_recving,
|
||||||
|
finished_loading_dict=finished_loading_dict,
|
||||||
)
|
)
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
@ -1505,6 +1507,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
self.maybe_wait_for_kv_save()
|
self.maybe_wait_for_kv_save()
|
||||||
finished_sending, finished_recving = (
|
finished_sending, finished_recving = (
|
||||||
self.get_finished_kv_transfers(scheduler_output))
|
self.get_finished_kv_transfers(scheduler_output))
|
||||||
|
finished_loading_dict = self.get_finished_loading(scheduler_output)
|
||||||
|
|
||||||
if self.use_aux_hidden_state_outputs:
|
if self.use_aux_hidden_state_outputs:
|
||||||
hidden_states, aux_hidden_states = model_output
|
hidden_states, aux_hidden_states = model_output
|
||||||
@ -1522,9 +1525,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
if not get_pp_group().is_last_rank:
|
if not get_pp_group().is_last_rank:
|
||||||
# For mid-pipeline stages, return the hidden states.
|
# For mid-pipeline stages, return the hidden states.
|
||||||
if not broadcast_pp_output:
|
if not broadcast_pp_output:
|
||||||
if finished_sending or finished_recving:
|
if (finished_sending or finished_recving
|
||||||
|
or finished_loading_dict):
|
||||||
hidden_states.finished_sending = finished_sending
|
hidden_states.finished_sending = finished_sending
|
||||||
hidden_states.finished_recving = finished_recving
|
hidden_states.finished_recving = finished_recving
|
||||||
|
hidden_states.finished_loading_dict = finished_loading_dict
|
||||||
return hidden_states
|
return hidden_states
|
||||||
assert isinstance(hidden_states, IntermediateTensors)
|
assert isinstance(hidden_states, IntermediateTensors)
|
||||||
get_pp_group().send_tensor_dict(hidden_states.tensors,
|
get_pp_group().send_tensor_dict(hidden_states.tensors,
|
||||||
@ -1534,7 +1539,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
if self.input_batch.pooling_params:
|
if self.input_batch.pooling_params:
|
||||||
return self._pool(hidden_states, num_scheduled_tokens,
|
return self._pool(hidden_states, num_scheduled_tokens,
|
||||||
num_scheduled_tokens_np, finished_sending,
|
num_scheduled_tokens_np, finished_sending,
|
||||||
finished_recving)
|
finished_recving, finished_loading_dict)
|
||||||
|
|
||||||
sample_hidden_states = hidden_states[logits_indices]
|
sample_hidden_states = hidden_states[logits_indices]
|
||||||
logits = self.model.compute_logits(sample_hidden_states, None)
|
logits = self.model.compute_logits(sample_hidden_states, None)
|
||||||
@ -1686,6 +1691,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
pooler_output=[],
|
pooler_output=[],
|
||||||
finished_sending=finished_sending,
|
finished_sending=finished_sending,
|
||||||
finished_recving=finished_recving,
|
finished_recving=finished_recving,
|
||||||
|
finished_loading_dict=finished_loading_dict,
|
||||||
num_nans_in_logits=num_nans_in_logits,
|
num_nans_in_logits=num_nans_in_logits,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -359,10 +359,12 @@ class Worker(WorkerBase):
|
|||||||
# In case of PP with kv transfer, we need to pass through the
|
# In case of PP with kv transfer, we need to pass through the
|
||||||
# finished_sending and finished_recving buffers.
|
# finished_sending and finished_recving buffers.
|
||||||
new_output = EMPTY_MODEL_RUNNER_OUTPUT
|
new_output = EMPTY_MODEL_RUNNER_OUTPUT
|
||||||
if output.finished_sending or output.finished_recving:
|
if (output.finished_sending or output.finished_recving
|
||||||
|
or output.finished_loading_dict):
|
||||||
new_output = copy.copy(new_output)
|
new_output = copy.copy(new_output)
|
||||||
new_output.finished_sending = output.finished_sending
|
new_output.finished_sending = output.finished_sending
|
||||||
new_output.finished_recving = output.finished_recving
|
new_output.finished_recving = output.finished_recving
|
||||||
|
new_output.finished_loading_dict = output.finished_loading_dict
|
||||||
output = new_output
|
output = new_output
|
||||||
|
|
||||||
assert isinstance(output, ModelRunnerOutput)
|
assert isinstance(output, ModelRunnerOutput)
|
||||||
|
|||||||
@ -53,6 +53,14 @@ class KVConnectorModelRunnerMixin:
|
|||||||
scheduler_output.finished_req_ids)
|
scheduler_output.finished_req_ids)
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_finished_loading(
|
||||||
|
scheduler_output: "SchedulerOutput", ) -> dict[str, int]:
|
||||||
|
if has_kv_transfer_group():
|
||||||
|
return get_kv_transfer_group().get_finished_loading(
|
||||||
|
scheduler_output)
|
||||||
|
return {}
|
||||||
|
|
||||||
def kv_connector_no_forward(self, scheduler_output: "SchedulerOutput",
|
def kv_connector_no_forward(self, scheduler_output: "SchedulerOutput",
|
||||||
vllm_config: VllmConfig) -> ModelRunnerOutput:
|
vllm_config: VllmConfig) -> ModelRunnerOutput:
|
||||||
# KV send/recv even if no work to do.
|
# KV send/recv even if no work to do.
|
||||||
@ -60,11 +68,14 @@ class KVConnectorModelRunnerMixin:
|
|||||||
self.maybe_setup_kv_connector(scheduler_output)
|
self.maybe_setup_kv_connector(scheduler_output)
|
||||||
finished_sending, finished_recving = (
|
finished_sending, finished_recving = (
|
||||||
self.get_finished_kv_transfers(scheduler_output))
|
self.get_finished_kv_transfers(scheduler_output))
|
||||||
|
finished_loading_dict = self.get_finished_loading(scheduler_output)
|
||||||
|
|
||||||
if not finished_sending and not finished_recving:
|
if (not finished_sending and not finished_recving
|
||||||
|
and not finished_loading_dict):
|
||||||
return EMPTY_MODEL_RUNNER_OUTPUT
|
return EMPTY_MODEL_RUNNER_OUTPUT
|
||||||
|
|
||||||
output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
|
output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
|
||||||
output.finished_sending = finished_sending
|
output.finished_sending = finished_sending
|
||||||
output.finished_recving = finished_recving
|
output.finished_recving = finished_recving
|
||||||
|
output.finished_loading_dict = finished_loading_dict
|
||||||
return output
|
return output
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user