mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-11 08:49:08 +08:00
Revert "[V1] Exception Handling when Loading KV Cache from Remote Store" (#21778)
Signed-off-by: KuntaiDu <kuntai@uchicago.edu>
This commit is contained in:
parent
9ba1c88a93
commit
b18b417fbf
@ -1,120 +0,0 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
||||||
import logging
|
|
||||||
import random
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import TYPE_CHECKING
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from vllm.config import VllmConfig
|
|
||||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
|
||||||
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
|
|
||||||
from vllm.v1.core.sched.output import SchedulerOutput
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from vllm.attention.backends.abstract import AttentionMetadata
|
|
||||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
|
||||||
from vllm.v1.request import Request
|
|
||||||
|
|
||||||
logger = logging.getLogger()
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class RandomDropConnectorMetadata(KVConnectorMetadata):
|
|
||||||
req_meta: dict[str, list[int]]
|
|
||||||
|
|
||||||
|
|
||||||
class RandomDropConnector(KVConnectorBase_V1):
|
|
||||||
"""
|
|
||||||
A connector designed for fault tolerance testing by randomly dropping
|
|
||||||
kv data during the process of loading or receiving KV cache.
|
|
||||||
|
|
||||||
This class simulates real-world scenarios where requests or data
|
|
||||||
might be lost or timeout, allowing developers to test and validate the
|
|
||||||
system's ability to handle such failures.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
finished_recving_kv_req_ids (set[str]): A set of request IDs that
|
|
||||||
have completed receiving KV cache data.
|
|
||||||
finished_loading_dict (dict[str, int]): A dictionary that tracks
|
|
||||||
the actual number of tokens loaded from the remote KV store
|
|
||||||
for each completed request. The keys are request IDs, and
|
|
||||||
the values are the corresponding token counts.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
|
|
||||||
super().__init__(vllm_config=vllm_config, role=role)
|
|
||||||
|
|
||||||
self.failure_request: list[str] = []
|
|
||||||
self._reqs_need_recv: dict[str, list[int]] = {}
|
|
||||||
self._finish_load: dict[str, int] = {}
|
|
||||||
|
|
||||||
self.chunk_size = 256
|
|
||||||
|
|
||||||
############################################################
|
|
||||||
# Scheduler Side Methods
|
|
||||||
############################################################
|
|
||||||
|
|
||||||
def get_num_new_matched_tokens(
|
|
||||||
self, request: "Request",
|
|
||||||
num_computed_tokens: int) -> tuple[int, bool]:
|
|
||||||
if request.request_id in self.failure_request:
|
|
||||||
self.failure_request.remove(request.request_id)
|
|
||||||
return 0, False
|
|
||||||
num_external_hit_tokens = request.num_prompt_tokens - 1
|
|
||||||
logger.info(
|
|
||||||
"request %s num_prompt_tokens %d num_external_hit_tokens %d",
|
|
||||||
request.request_id, request.num_prompt_tokens,
|
|
||||||
num_external_hit_tokens)
|
|
||||||
return num_external_hit_tokens, True
|
|
||||||
|
|
||||||
def update_state_after_alloc(self, request: "Request",
|
|
||||||
blocks: "KVCacheBlocks",
|
|
||||||
num_external_tokens: int):
|
|
||||||
if num_external_tokens > 0:
|
|
||||||
self._reqs_need_recv[
|
|
||||||
request.
|
|
||||||
request_id] = request.prompt_token_ids[:num_external_tokens]
|
|
||||||
|
|
||||||
def build_connector_meta(
|
|
||||||
self,
|
|
||||||
scheduler_output: SchedulerOutput,
|
|
||||||
) -> KVConnectorMetadata:
|
|
||||||
req_meta = self._reqs_need_recv.copy()
|
|
||||||
self._reqs_need_recv.clear()
|
|
||||||
return RandomDropConnectorMetadata(req_meta)
|
|
||||||
|
|
||||||
def add_failure_request(self, request: "Request"):
|
|
||||||
self.failure_request.append(request.request_id)
|
|
||||||
|
|
||||||
def start_load_kv(self, forward_context, **kwargs) -> None:
|
|
||||||
for request_id, hit_tokens in self._get_connector_metadata(
|
|
||||||
).req_meta.items():
|
|
||||||
num_actual_load_tokens = self.load_kv(request_id, hit_tokens)
|
|
||||||
logger.info("request %s hit_tokens %d num_actual_load_tokens %d",
|
|
||||||
request_id, len(hit_tokens), num_actual_load_tokens)
|
|
||||||
self._finish_load[request_id] = num_actual_load_tokens
|
|
||||||
|
|
||||||
def wait_for_layer_load(self, layer_name: str) -> None:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
|
|
||||||
attn_metadata: "AttentionMetadata", **kwargs) -> None:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def wait_for_save(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def load_kv(self, request_id, hit_tokens):
|
|
||||||
num_actual_load_tokens = random.randint(0, len(hit_tokens))
|
|
||||||
return num_actual_load_tokens
|
|
||||||
|
|
||||||
def get_finished_loading(self) -> dict[str, int]:
|
|
||||||
if not self._finish_load:
|
|
||||||
return {}
|
|
||||||
finished_loading = self._finish_load.copy()
|
|
||||||
self._finish_load.clear()
|
|
||||||
|
|
||||||
return finished_loading
|
|
||||||
@ -1,16 +0,0 @@
|
|||||||
#!/bin/bash
|
|
||||||
|
|
||||||
SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
|
|
||||||
export PYTHONPATH=$PYTHONPATH:$SCRIPT_DIR
|
|
||||||
|
|
||||||
vllm serve DeepSeek-V2-Lite-Chat \
|
|
||||||
--trust-remote-code \
|
|
||||||
--served-model-name vllm_cpu_offload \
|
|
||||||
--max-model-len 32768 \
|
|
||||||
--no-enable-prefix-caching \
|
|
||||||
--max-seq-len-to-capture 10000 \
|
|
||||||
--max-num-seqs 64 \
|
|
||||||
--gpu-memory-utilization 0.9 \
|
|
||||||
--host 0.0.0.0 \
|
|
||||||
-tp 2 \
|
|
||||||
--kv-transfer-config '{"kv_connector":"RandomDropConnector","kv_role":"kv_both","kv_connector_module_path":"random_drop_connector"}'
|
|
||||||
@ -139,27 +139,13 @@ class KVOutputAggregator:
|
|||||||
finished_set.add(req_id)
|
finished_set.add(req_id)
|
||||||
del remaining_count_dict[req_id]
|
del remaining_count_dict[req_id]
|
||||||
|
|
||||||
def update_finished_load_dict(worker_finished_loading_dict: dict[str,
|
|
||||||
int],
|
|
||||||
finished_loading_dict: dict[str, int]):
|
|
||||||
for req_id, num_actual_load_tokens in (worker_finished_loading_dict
|
|
||||||
or {}).items():
|
|
||||||
if req_id in finished_loading_dict:
|
|
||||||
finished_loading_dict[req_id] = min(
|
|
||||||
finished_loading_dict[req_id], num_actual_load_tokens)
|
|
||||||
else:
|
|
||||||
finished_loading_dict[req_id] = num_actual_load_tokens
|
|
||||||
|
|
||||||
finished_sending = set[str]()
|
finished_sending = set[str]()
|
||||||
finished_recving = set[str]()
|
finished_recving = set[str]()
|
||||||
finished_loading_dict: dict[str, int] = {}
|
|
||||||
for output in outputs:
|
for output in outputs:
|
||||||
update_finished_set(output.finished_sending,
|
update_finished_set(output.finished_sending,
|
||||||
self._send_remaining_count, finished_sending)
|
self._send_remaining_count, finished_sending)
|
||||||
update_finished_set(output.finished_recving,
|
update_finished_set(output.finished_recving,
|
||||||
self._recv_remaining_count, finished_recving)
|
self._recv_remaining_count, finished_recving)
|
||||||
update_finished_load_dict(output.finished_loading_dict,
|
|
||||||
finished_loading_dict)
|
|
||||||
|
|
||||||
# select output of the worker specified by output_rank
|
# select output of the worker specified by output_rank
|
||||||
output = outputs[output_rank]
|
output = outputs[output_rank]
|
||||||
@ -171,7 +157,7 @@ class KVOutputAggregator:
|
|||||||
# send/recv
|
# send/recv
|
||||||
output.finished_sending = finished_sending if finished_sending else None
|
output.finished_sending = finished_sending if finished_sending else None
|
||||||
output.finished_recving = finished_recving if finished_recving else None
|
output.finished_recving = finished_recving if finished_recving else None
|
||||||
output.finished_loading_dict = finished_loading_dict or None
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def async_aggregate(self,
|
def async_aggregate(self,
|
||||||
|
|||||||
@ -28,9 +28,6 @@ The class provides the following primitives:
|
|||||||
|
|
||||||
get_finished() - called with ids of finished requests, returns
|
get_finished() - called with ids of finished requests, returns
|
||||||
ids of requests that have completed async sending/recving.
|
ids of requests that have completed async sending/recving.
|
||||||
get_finished_loading() - called with scheduler outputs, returns
|
|
||||||
a dictionary that the keys are request IDs and the values are
|
|
||||||
the actual number of tokens loaded from the remote KV cache
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import enum
|
import enum
|
||||||
@ -222,23 +219,6 @@ class KVConnectorBase_V1(ABC):
|
|||||||
"""
|
"""
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
def get_finished_loading(
|
|
||||||
self, scheduler_output: "SchedulerOutput") -> dict[str, int]:
|
|
||||||
"""
|
|
||||||
Retrieves the actual number of tokens loaded for requests that have
|
|
||||||
completed the asynchronous loading process from the remote KV cache.
|
|
||||||
|
|
||||||
This function is used by the scheduler process (via the Executors)
|
|
||||||
to track the progress of requests and determine which requests have
|
|
||||||
successfully finished loading their KV cache data.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A dictionary where the keys are request IDs and the values are the
|
|
||||||
corresponding number of tokens that have been successfully loaded
|
|
||||||
for each request.
|
|
||||||
"""
|
|
||||||
return {}
|
|
||||||
|
|
||||||
# ==============================
|
# ==============================
|
||||||
# Scheduler-side methods
|
# Scheduler-side methods
|
||||||
# ==============================
|
# ==============================
|
||||||
|
|||||||
@ -1167,8 +1167,6 @@ class IntermediateTensors:
|
|||||||
# [req_ids]
|
# [req_ids]
|
||||||
finished_sending: Optional[set[str]] = None
|
finished_sending: Optional[set[str]] = None
|
||||||
finished_recving: Optional[set[str]] = None
|
finished_recving: Optional[set[str]] = None
|
||||||
#req_id -> num_actual_load_tokens
|
|
||||||
finished_loading_dict: Optional[dict[str, int]] = None
|
|
||||||
|
|
||||||
def __init__(self, tensors):
|
def __init__(self, tensors):
|
||||||
# manually define this function, so that
|
# manually define this function, so that
|
||||||
|
|||||||
@ -118,9 +118,6 @@ class Scheduler(SchedulerInterface):
|
|||||||
|
|
||||||
# KV Connector: requests in process of async KV loading or recving
|
# KV Connector: requests in process of async KV loading or recving
|
||||||
self.finished_recving_kv_req_ids: set[str] = set()
|
self.finished_recving_kv_req_ids: set[str] = set()
|
||||||
# The keys are request IDs, and the values are corresponding token
|
|
||||||
# count that have been successfully loaded from the remote KV store
|
|
||||||
self.finished_loading_dict: dict[str, int] = {}
|
|
||||||
|
|
||||||
# Encoder-related.
|
# Encoder-related.
|
||||||
# Calculate encoder cache size if applicable
|
# Calculate encoder cache size if applicable
|
||||||
@ -1097,27 +1094,6 @@ class Scheduler(SchedulerInterface):
|
|||||||
(block_ids, ) = self.kv_cache_manager.get_block_ids(request.request_id)
|
(block_ids, ) = self.kv_cache_manager.get_block_ids(request.request_id)
|
||||||
return self.connector.request_finished(request, block_ids)
|
return self.connector.request_finished(request, block_ids)
|
||||||
|
|
||||||
def _update_actual_load_token_num_from_remote_kv(self,
|
|
||||||
request: Request) -> bool:
|
|
||||||
|
|
||||||
num_actual_load_tokens = self.finished_loading_dict.pop(
|
|
||||||
request.request_id)
|
|
||||||
num_computed_tokens = num_actual_load_tokens
|
|
||||||
assert self.connector is not None
|
|
||||||
if num_actual_load_tokens <= 0 and hasattr(self.connector,
|
|
||||||
"add_failure_request"):
|
|
||||||
self.connector.add_failure_request(request)
|
|
||||||
return True
|
|
||||||
|
|
||||||
if num_actual_load_tokens == request.num_tokens:
|
|
||||||
num_computed_tokens -= 1
|
|
||||||
|
|
||||||
self.kv_cache_manager.cache_blocks(request, num_computed_tokens)
|
|
||||||
|
|
||||||
# Update the request state for scheduling.
|
|
||||||
request.num_computed_tokens = num_computed_tokens
|
|
||||||
return True
|
|
||||||
|
|
||||||
def _update_waiting_for_remote_kv(self, request: Request) -> bool:
|
def _update_waiting_for_remote_kv(self, request: Request) -> bool:
|
||||||
"""
|
"""
|
||||||
KV Connector: check if the request_id is finished_recving.
|
KV Connector: check if the request_id is finished_recving.
|
||||||
@ -1131,9 +1107,6 @@ class Scheduler(SchedulerInterface):
|
|||||||
WAITING_FOR_REMOTE_KV.
|
WAITING_FOR_REMOTE_KV.
|
||||||
"""
|
"""
|
||||||
assert self.connector is not None
|
assert self.connector is not None
|
||||||
if request.request_id in self.finished_loading_dict:
|
|
||||||
return self._update_actual_load_token_num_from_remote_kv(request)
|
|
||||||
|
|
||||||
if request.request_id not in self.finished_recving_kv_req_ids:
|
if request.request_id not in self.finished_recving_kv_req_ids:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@ -1172,6 +1145,3 @@ class Scheduler(SchedulerInterface):
|
|||||||
for req_id in (model_runner_output.finished_sending or ()):
|
for req_id in (model_runner_output.finished_sending or ()):
|
||||||
logger.debug("Finished sending KV transfer for request %s", req_id)
|
logger.debug("Finished sending KV transfer for request %s", req_id)
|
||||||
self._free_blocks(self.requests[req_id])
|
self._free_blocks(self.requests[req_id])
|
||||||
if model_runner_output.finished_loading_dict:
|
|
||||||
self.finished_loading_dict.update(
|
|
||||||
model_runner_output.finished_loading_dict)
|
|
||||||
|
|||||||
@ -107,8 +107,6 @@ class ModelRunnerOutput:
|
|||||||
# [req_ids]
|
# [req_ids]
|
||||||
finished_sending: Optional[set[str]] = None
|
finished_sending: Optional[set[str]] = None
|
||||||
finished_recving: Optional[set[str]] = None
|
finished_recving: Optional[set[str]] = None
|
||||||
# req_id -> actual_load_token from connector
|
|
||||||
finished_loading_dict: Optional[dict[str, int]] = None
|
|
||||||
|
|
||||||
# req_id -> num_nans_in_logits
|
# req_id -> num_nans_in_logits
|
||||||
num_nans_in_logits: Optional[dict[str, int]] = None
|
num_nans_in_logits: Optional[dict[str, int]] = None
|
||||||
@ -123,5 +121,4 @@ EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(req_ids=[],
|
|||||||
pooler_output=[],
|
pooler_output=[],
|
||||||
finished_sending=None,
|
finished_sending=None,
|
||||||
finished_recving=None,
|
finished_recving=None,
|
||||||
finished_loading_dict=None,
|
|
||||||
num_nans_in_logits=None)
|
num_nans_in_logits=None)
|
||||||
|
|||||||
@ -1375,7 +1375,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
num_scheduled_tokens_np: np.ndarray,
|
num_scheduled_tokens_np: np.ndarray,
|
||||||
finished_sending: Optional[set[str]],
|
finished_sending: Optional[set[str]],
|
||||||
finished_recving: Optional[set[str]],
|
finished_recving: Optional[set[str]],
|
||||||
finished_loading_dict: Optional[dict[str, int]],
|
|
||||||
) -> ModelRunnerOutput:
|
) -> ModelRunnerOutput:
|
||||||
assert self.input_batch.num_reqs ==\
|
assert self.input_batch.num_reqs ==\
|
||||||
len(self.input_batch.pooling_params), \
|
len(self.input_batch.pooling_params), \
|
||||||
@ -1412,7 +1411,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
pooler_output=pooler_output,
|
pooler_output=pooler_output,
|
||||||
finished_sending=finished_sending,
|
finished_sending=finished_sending,
|
||||||
finished_recving=finished_recving,
|
finished_recving=finished_recving,
|
||||||
finished_loading_dict=finished_loading_dict,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
@ -1532,7 +1530,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
self.maybe_wait_for_kv_save()
|
self.maybe_wait_for_kv_save()
|
||||||
finished_sending, finished_recving = (
|
finished_sending, finished_recving = (
|
||||||
self.get_finished_kv_transfers(scheduler_output))
|
self.get_finished_kv_transfers(scheduler_output))
|
||||||
finished_loading_dict = self.get_finished_loading(scheduler_output)
|
|
||||||
|
|
||||||
if self.use_aux_hidden_state_outputs:
|
if self.use_aux_hidden_state_outputs:
|
||||||
hidden_states, aux_hidden_states = model_output
|
hidden_states, aux_hidden_states = model_output
|
||||||
@ -1550,11 +1547,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
if not get_pp_group().is_last_rank:
|
if not get_pp_group().is_last_rank:
|
||||||
# For mid-pipeline stages, return the hidden states.
|
# For mid-pipeline stages, return the hidden states.
|
||||||
if not broadcast_pp_output:
|
if not broadcast_pp_output:
|
||||||
if (finished_sending or finished_recving
|
if finished_sending or finished_recving:
|
||||||
or finished_loading_dict):
|
|
||||||
hidden_states.finished_sending = finished_sending
|
hidden_states.finished_sending = finished_sending
|
||||||
hidden_states.finished_recving = finished_recving
|
hidden_states.finished_recving = finished_recving
|
||||||
hidden_states.finished_loading_dict = finished_loading_dict
|
|
||||||
return hidden_states
|
return hidden_states
|
||||||
assert isinstance(hidden_states, IntermediateTensors)
|
assert isinstance(hidden_states, IntermediateTensors)
|
||||||
get_pp_group().send_tensor_dict(hidden_states.tensors,
|
get_pp_group().send_tensor_dict(hidden_states.tensors,
|
||||||
@ -1564,7 +1559,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
if self.input_batch.pooling_params:
|
if self.input_batch.pooling_params:
|
||||||
return self._pool(hidden_states, num_scheduled_tokens,
|
return self._pool(hidden_states, num_scheduled_tokens,
|
||||||
num_scheduled_tokens_np, finished_sending,
|
num_scheduled_tokens_np, finished_sending,
|
||||||
finished_recving, finished_loading_dict)
|
finished_recving)
|
||||||
|
|
||||||
sample_hidden_states = hidden_states[logits_indices]
|
sample_hidden_states = hidden_states[logits_indices]
|
||||||
logits = self.model.compute_logits(sample_hidden_states, None)
|
logits = self.model.compute_logits(sample_hidden_states, None)
|
||||||
@ -1716,7 +1711,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
pooler_output=[],
|
pooler_output=[],
|
||||||
finished_sending=finished_sending,
|
finished_sending=finished_sending,
|
||||||
finished_recving=finished_recving,
|
finished_recving=finished_recving,
|
||||||
finished_loading_dict=finished_loading_dict,
|
|
||||||
num_nans_in_logits=num_nans_in_logits,
|
num_nans_in_logits=num_nans_in_logits,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -359,12 +359,10 @@ class Worker(WorkerBase):
|
|||||||
# In case of PP with kv transfer, we need to pass through the
|
# In case of PP with kv transfer, we need to pass through the
|
||||||
# finished_sending and finished_recving buffers.
|
# finished_sending and finished_recving buffers.
|
||||||
new_output = EMPTY_MODEL_RUNNER_OUTPUT
|
new_output = EMPTY_MODEL_RUNNER_OUTPUT
|
||||||
if (output.finished_sending or output.finished_recving
|
if output.finished_sending or output.finished_recving:
|
||||||
or output.finished_loading_dict):
|
|
||||||
new_output = copy.copy(new_output)
|
new_output = copy.copy(new_output)
|
||||||
new_output.finished_sending = output.finished_sending
|
new_output.finished_sending = output.finished_sending
|
||||||
new_output.finished_recving = output.finished_recving
|
new_output.finished_recving = output.finished_recving
|
||||||
new_output.finished_loading_dict = output.finished_loading_dict
|
|
||||||
output = new_output
|
output = new_output
|
||||||
|
|
||||||
assert isinstance(output, ModelRunnerOutput)
|
assert isinstance(output, ModelRunnerOutput)
|
||||||
|
|||||||
@ -53,14 +53,6 @@ class KVConnectorModelRunnerMixin:
|
|||||||
scheduler_output.finished_req_ids)
|
scheduler_output.finished_req_ids)
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_finished_loading(
|
|
||||||
scheduler_output: "SchedulerOutput", ) -> dict[str, int]:
|
|
||||||
if has_kv_transfer_group():
|
|
||||||
return get_kv_transfer_group().get_finished_loading(
|
|
||||||
scheduler_output)
|
|
||||||
return {}
|
|
||||||
|
|
||||||
def kv_connector_no_forward(self, scheduler_output: "SchedulerOutput",
|
def kv_connector_no_forward(self, scheduler_output: "SchedulerOutput",
|
||||||
vllm_config: VllmConfig) -> ModelRunnerOutput:
|
vllm_config: VllmConfig) -> ModelRunnerOutput:
|
||||||
# KV send/recv even if no work to do.
|
# KV send/recv even if no work to do.
|
||||||
@ -68,14 +60,11 @@ class KVConnectorModelRunnerMixin:
|
|||||||
self.maybe_setup_kv_connector(scheduler_output)
|
self.maybe_setup_kv_connector(scheduler_output)
|
||||||
finished_sending, finished_recving = (
|
finished_sending, finished_recving = (
|
||||||
self.get_finished_kv_transfers(scheduler_output))
|
self.get_finished_kv_transfers(scheduler_output))
|
||||||
finished_loading_dict = self.get_finished_loading(scheduler_output)
|
|
||||||
|
|
||||||
if (not finished_sending and not finished_recving
|
if not finished_sending and not finished_recving:
|
||||||
and not finished_loading_dict):
|
|
||||||
return EMPTY_MODEL_RUNNER_OUTPUT
|
return EMPTY_MODEL_RUNNER_OUTPUT
|
||||||
|
|
||||||
output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
|
output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
|
||||||
output.finished_sending = finished_sending
|
output.finished_sending = finished_sending
|
||||||
output.finished_recving = finished_recving
|
output.finished_recving = finished_recving
|
||||||
output.finished_loading_dict = finished_loading_dict
|
|
||||||
return output
|
return output
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user