diff --git a/tests/v1/kv_connector/kv_load_exception_handling/random_drop_connector.py b/tests/v1/kv_connector/kv_load_exception_handling/random_drop_connector.py
new file mode 100644
index 000000000000..216029a6ad7a
--- /dev/null
+++ b/tests/v1/kv_connector/kv_load_exception_handling/random_drop_connector.py
@@ -0,0 +1,120 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import logging
+import random
+from dataclasses import dataclass
+from typing import TYPE_CHECKING
+
+import torch
+
+from vllm.config import VllmConfig
+from vllm.distributed.kv_transfer.kv_connector.v1.base import (
+    KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
+from vllm.v1.core.sched.output import SchedulerOutput
+
+if TYPE_CHECKING:
+    from vllm.attention.backends.abstract import AttentionMetadata
+    from vllm.v1.core.kv_cache_manager import KVCacheBlocks
+    from vllm.v1.request import Request
+
+logger = logging.getLogger()
+logging.basicConfig(level=logging.INFO)
+
+
+@dataclass
+class RandomDropConnectorMetadata(KVConnectorMetadata):
+    req_meta: dict[str, list[int]]
+
+
+class RandomDropConnector(KVConnectorBase_V1):
+    """
+    A connector for fault-tolerance testing that randomly drops KV data
+    while loading or receiving the KV cache.
+
+    This simulates real-world scenarios where requests or data may be lost
+    or time out, allowing developers to test and validate the system's
+    ability to handle such failures.
+
+    Attributes:
+        failure_request (list[str]): Request IDs that should be treated as
+            failed KV loads the next time they are scheduled.
+        _reqs_need_recv (dict[str, list[int]]): Maps request IDs to the
+            prompt tokens that need to be loaded from the remote KV store.
+        _finish_load (dict[str, int]): Maps request IDs to the actual
+            number of tokens loaded from the remote KV store.
+    """
+
+    def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
+        super().__init__(vllm_config=vllm_config, role=role)
+
+        self.failure_request: list[str] = []
+        self._reqs_need_recv: dict[str, list[int]] = {}
+        self._finish_load: dict[str, int] = {}
+
+        self.chunk_size = 256
+
+    ############################################################
+    # Scheduler Side Methods
+    ############################################################
+
+    def get_num_new_matched_tokens(
+            self, request: "Request",
+            num_computed_tokens: int) -> tuple[int, bool]:
+        if request.request_id in self.failure_request:
+            self.failure_request.remove(request.request_id)
+            return 0, False
+        num_external_hit_tokens = request.num_prompt_tokens - 1
+        logger.info(
+            "request %s num_prompt_tokens %d num_external_hit_tokens %d",
+            request.request_id, request.num_prompt_tokens,
+            num_external_hit_tokens)
+        return num_external_hit_tokens, True
+
+    def update_state_after_alloc(self, request: "Request",
+                                 blocks: "KVCacheBlocks",
+                                 num_external_tokens: int):
+        if num_external_tokens > 0:
+            self._reqs_need_recv[
+                request.
+                request_id] = request.prompt_token_ids[:num_external_tokens]
+
+    def build_connector_meta(
+        self,
+        scheduler_output: SchedulerOutput,
+    ) -> KVConnectorMetadata:
+        req_meta = self._reqs_need_recv.copy()
+        self._reqs_need_recv.clear()
+        return RandomDropConnectorMetadata(req_meta)
+
+    def add_failure_request(self, request: "Request"):
+        self.failure_request.append(request.request_id)
+
+    def start_load_kv(self, forward_context, **kwargs) -> None:
+        for request_id, hit_tokens in self._get_connector_metadata(
+        ).req_meta.items():
+            num_actual_load_tokens = self.load_kv(request_id, hit_tokens)
+            logger.info("request %s hit_tokens %d num_actual_load_tokens %d",
+                        request_id, len(hit_tokens), num_actual_load_tokens)
+            self._finish_load[request_id] = num_actual_load_tokens
+
+    def wait_for_layer_load(self, layer_name: str) -> None:
+        pass
+
+    def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
+                      attn_metadata: "AttentionMetadata", **kwargs) -> None:
+        pass
+
+    def wait_for_save(self):
+        pass
+
+    def load_kv(self, request_id, hit_tokens):
+        # Simulate a partial or failed load by reporting a random number of
+        # successfully loaded tokens in [0, len(hit_tokens)].
+        num_actual_load_tokens = random.randint(0, len(hit_tokens))
+        return num_actual_load_tokens
+
+    def get_finished_loading(
+            self, scheduler_output: SchedulerOutput) -> dict[str, int]:
+        if not self._finish_load:
+            return {}
+        finished_loading = self._finish_load.copy()
+        self._finish_load.clear()
+
+        return finished_loading
diff --git a/tests/v1/kv_connector/kv_load_exception_handling/test.sh b/tests/v1/kv_connector/kv_load_exception_handling/test.sh
new file mode 100644
index 000000000000..443251153319
--- /dev/null
+++ b/tests/v1/kv_connector/kv_load_exception_handling/test.sh
@@ -0,0 +1,16 @@
+#!/bin/bash
+
+SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
+export PYTHONPATH=$PYTHONPATH:$SCRIPT_DIR
+
+vllm serve DeepSeek-V2-Lite-Chat \
+--trust-remote-code \
+--served-model-name vllm_cpu_offload \
+--max-model-len 32768 \
+--no-enable-prefix-caching \
+--max-seq-len-to-capture 10000 \
+--max-num-seqs 64 \
+--gpu-memory-utilization 0.9 \
+--host 0.0.0.0 \
+-tp 2 \
+--kv-transfer-config '{"kv_connector":"RandomDropConnector","kv_role":"kv_both","kv_connector_module_path":"random_drop_connector"}'
diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py
index 459a53298914..fed4349277c6 100644
--- a/vllm/distributed/kv_transfer/kv_connector/utils.py
+++ b/vllm/distributed/kv_transfer/kv_connector/utils.py
@@ -139,13 +139,27 @@ class KVOutputAggregator:
                     finished_set.add(req_id)
                     del remaining_count_dict[req_id]
 
+        def update_finished_load_dict(worker_finished_loading_dict: dict[str,
+                                                                         int],
+                                      finished_loading_dict: dict[str, int]):
+            # Keep the minimum count per request: the KV cache is only usable
+            # up to the smallest number of tokens loaded by any worker.
+            for req_id, num_actual_load_tokens in (worker_finished_loading_dict
+                                                   or {}).items():
+                if req_id in finished_loading_dict:
+                    finished_loading_dict[req_id] = min(
+                        finished_loading_dict[req_id], num_actual_load_tokens)
+                else:
+                    finished_loading_dict[req_id] = num_actual_load_tokens
+
         finished_sending = set[str]()
         finished_recving = set[str]()
+        finished_loading_dict: dict[str, int] = {}
         for output in outputs:
             update_finished_set(output.finished_sending,
                                 self._send_remaining_count, finished_sending)
             update_finished_set(output.finished_recving,
                                 self._recv_remaining_count, finished_recving)
+            update_finished_load_dict(output.finished_loading_dict,
+                                      finished_loading_dict)
 
         # select output of the worker specified by output_rank
         output = outputs[output_rank]
@@ -157,7 +171,7 @@ class KVOutputAggregator:
         # send/recv
         output.finished_sending = finished_sending if finished_sending else None
         output.finished_recving = finished_recving if finished_recving else None
-
+        output.finished_loading_dict = finished_loading_dict or None
         return output
 
     def async_aggregate(self,
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py
index 8bbdd7e0621c..9dfb6a08867a 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py
@@ -28,6 +28,9 @@ The class provides the following primitives:
     get_finished() - called with ids of finished requests, returns
         ids of requests that have completed async sending/recving.
 
+    get_finished_loading() - called with the scheduler output, returns
+        a dictionary mapping request IDs to the actual number of tokens
+        loaded from the remote KV cache
 """

 import enum
@@ -219,6 +222,23 @@ class KVConnectorBase_V1(ABC):
         """
         return None, None
 
+    def get_finished_loading(
+            self, scheduler_output: "SchedulerOutput") -> dict[str, int]:
+        """
+        Retrieve the actual number of tokens loaded for requests that have
+        completed asynchronous loading from the remote KV cache.
+
+        This is used by the scheduler process (via the Executors) to track
+        the progress of requests and determine which requests have
+        successfully finished loading their KV cache data.
+
+        Returns:
+            A dictionary where the keys are request IDs and the values are
+            the number of tokens that have been successfully loaded for
+            each request.
+        """
+        return {}
+
     # ==============================
     # Scheduler-side methods
     # ==============================
diff --git a/vllm/sequence.py b/vllm/sequence.py
index fe87b52f9df1..1a9095ebee12 100644
--- a/vllm/sequence.py
+++ b/vllm/sequence.py
@@ -1167,6 +1167,8 @@ class IntermediateTensors:
     # [req_ids]
     finished_sending: Optional[set[str]] = None
     finished_recving: Optional[set[str]] = None
+    # req_id -> num_actual_load_tokens
+    finished_loading_dict: Optional[dict[str, int]] = None
 
     def __init__(self, tensors):
         # manually define this function, so that
diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
index 446f98034cb8..2907c3f27b24 100644
--- a/vllm/v1/core/sched/scheduler.py
+++ b/vllm/v1/core/sched/scheduler.py
@@ -118,6 +118,9 @@ class Scheduler(SchedulerInterface):
 
         # KV Connector: requests in process of async KV loading or recving
         self.finished_recving_kv_req_ids: set[str] = set()
+        # Maps request IDs to the number of tokens that have been
+        # successfully loaded from the remote KV store
+        self.finished_loading_dict: dict[str, int] = {}
 
         # Encoder-related.
         # Calculate encoder cache size if applicable
@@ -1094,6 +1097,27 @@ class Scheduler(SchedulerInterface):
         (block_ids, ) = self.kv_cache_manager.get_block_ids(request.request_id)
         return self.connector.request_finished(request, block_ids)
 
+    def _update_actual_load_token_num_from_remote_kv(self,
+                                                     request: Request) -> bool:
+        num_actual_load_tokens = self.finished_loading_dict.pop(
+            request.request_id)
+        num_computed_tokens = num_actual_load_tokens
+        assert self.connector is not None
+        if num_actual_load_tokens <= 0 and hasattr(self.connector,
+                                                   "add_failure_request"):
+            # Nothing usable was loaded; report the failure to the connector
+            # so the request can be rescheduled without the remote KV.
+            self.connector.add_failure_request(request)
+            return True
+
+        if num_actual_load_tokens == request.num_tokens:
+            # Leave the last token to be recomputed so that logits are
+            # produced for sampling.
+            num_computed_tokens -= 1
+
+        self.kv_cache_manager.cache_blocks(request, num_computed_tokens)
+
+        # Update the request state for scheduling.
+        request.num_computed_tokens = num_computed_tokens
+        return True
+
     def _update_waiting_for_remote_kv(self, request: Request) -> bool:
         """
         KV Connector: check if the request_id is finished_recving.
@@ -1107,6 +1131,9 @@ class Scheduler(SchedulerInterface):
             WAITING_FOR_REMOTE_KV.
         """
         assert self.connector is not None
+        if request.request_id in self.finished_loading_dict:
+            return self._update_actual_load_token_num_from_remote_kv(request)
+
         if request.request_id not in self.finished_recving_kv_req_ids:
             return False
 
@@ -1145,3 +1172,6 @@ class Scheduler(SchedulerInterface):
         for req_id in (model_runner_output.finished_sending or ()):
             logger.debug("Finished sending KV transfer for request %s", req_id)
             self._free_blocks(self.requests[req_id])
+        if model_runner_output.finished_loading_dict:
+            self.finished_loading_dict.update(
+                model_runner_output.finished_loading_dict)
diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py
index f78623f571b2..0b757a297db4 100644
--- a/vllm/v1/outputs.py
+++ b/vllm/v1/outputs.py
@@ -107,6 +107,8 @@ class ModelRunnerOutput:
     # [req_ids]
     finished_sending: Optional[set[str]] = None
     finished_recving: Optional[set[str]] = None
+    # req_id -> number of tokens actually loaded by the connector
+    finished_loading_dict: Optional[dict[str, int]] = None
 
     # req_id -> num_nans_in_logits
     num_nans_in_logits: Optional[dict[str, int]] = None
@@ -121,4 +123,5 @@ EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(req_ids=[],
                                               pooler_output=[],
                                               finished_sending=None,
                                               finished_recving=None,
+                                              finished_loading_dict=None,
                                               num_nans_in_logits=None)
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 8d63ee923e6c..bf670de324a7 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -1350,6 +1350,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         num_scheduled_tokens_np: np.ndarray,
         finished_sending: Optional[set[str]],
         finished_recving: Optional[set[str]],
+        finished_loading_dict: Optional[dict[str, int]],
     ) -> ModelRunnerOutput:
         assert self.input_batch.num_reqs ==\
             len(self.input_batch.pooling_params), \
@@ -1386,6 +1387,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             pooler_output=pooler_output,
             finished_sending=finished_sending,
             finished_recving=finished_recving,
+            finished_loading_dict=finished_loading_dict,
         )
 
     @torch.inference_mode()
@@ -1505,6 +1507,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 self.maybe_wait_for_kv_save()
                 finished_sending, finished_recving = (
                     self.get_finished_kv_transfers(scheduler_output))
+                finished_loading_dict = self.get_finished_loading(scheduler_output)
 
             if self.use_aux_hidden_state_outputs:
                 hidden_states, aux_hidden_states = model_output
@@ -1522,9 +1525,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         if not get_pp_group().is_last_rank:
             # For mid-pipeline stages, return the hidden states.
             if not broadcast_pp_output:
-                if finished_sending or finished_recving:
+                if (finished_sending or finished_recving
+                        or finished_loading_dict):
                     hidden_states.finished_sending = finished_sending
                     hidden_states.finished_recving = finished_recving
+                    hidden_states.finished_loading_dict = finished_loading_dict
                 return hidden_states
             assert isinstance(hidden_states, IntermediateTensors)
             get_pp_group().send_tensor_dict(hidden_states.tensors,
@@ -1534,7 +1539,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         if self.input_batch.pooling_params:
             return self._pool(hidden_states, num_scheduled_tokens,
                               num_scheduled_tokens_np, finished_sending,
-                              finished_recving)
+                              finished_recving, finished_loading_dict)
 
         sample_hidden_states = hidden_states[logits_indices]
         logits = self.model.compute_logits(sample_hidden_states, None)
@@ -1686,6 +1691,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             pooler_output=[],
             finished_sending=finished_sending,
             finished_recving=finished_recving,
+            finished_loading_dict=finished_loading_dict,
             num_nans_in_logits=num_nans_in_logits,
         )
diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py
index d9d1f14f0554..50618c9ce8b8 100644
--- a/vllm/v1/worker/gpu_worker.py
+++ b/vllm/v1/worker/gpu_worker.py
@@ -359,10 +359,12 @@ class Worker(WorkerBase):
             # In case of PP with kv transfer, we need to pass through the
             # finished_sending and finished_recving buffers.
             new_output = EMPTY_MODEL_RUNNER_OUTPUT
-            if output.finished_sending or output.finished_recving:
+            if (output.finished_sending or output.finished_recving
+                    or output.finished_loading_dict):
                 new_output = copy.copy(new_output)
                 new_output.finished_sending = output.finished_sending
                 new_output.finished_recving = output.finished_recving
+                new_output.finished_loading_dict = output.finished_loading_dict
             output = new_output
 
         assert isinstance(output, ModelRunnerOutput)
diff --git a/vllm/v1/worker/kv_connector_model_runner_mixin.py b/vllm/v1/worker/kv_connector_model_runner_mixin.py
index 5a3186058fcf..d3204ca47f19 100644
--- a/vllm/v1/worker/kv_connector_model_runner_mixin.py
+++ b/vllm/v1/worker/kv_connector_model_runner_mixin.py
@@ -53,6 +53,14 @@ class KVConnectorModelRunnerMixin:
                 scheduler_output.finished_req_ids)
         return None, None
 
+    @staticmethod
+    def get_finished_loading(
+            scheduler_output: "SchedulerOutput") -> dict[str, int]:
+        if has_kv_transfer_group():
+            return get_kv_transfer_group().get_finished_loading(
+                scheduler_output)
+        return {}
+
     def kv_connector_no_forward(self, scheduler_output: "SchedulerOutput",
                                 vllm_config: VllmConfig) -> ModelRunnerOutput:
         # KV send/recv even if no work to do.
@@ -60,11 +68,14 @@ class KVConnectorModelRunnerMixin:
             self.maybe_setup_kv_connector(scheduler_output)
             finished_sending, finished_recving = (
                 self.get_finished_kv_transfers(scheduler_output))
+            finished_loading_dict = self.get_finished_loading(scheduler_output)
 
-        if not finished_sending and not finished_recving:
+        if (not finished_sending and not finished_recving
+                and not finished_loading_dict):
             return EMPTY_MODEL_RUNNER_OUTPUT
 
         output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
         output.finished_sending = finished_sending
         output.finished_recving = finished_recving
+        output.finished_loading_dict = finished_loading_dict
         return output
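
For reference, a minimal plain-Python sketch of the new data flow (the helpers below are hypothetical, not vLLM APIs): per-worker finished_loading_dict values are min-reduced by KVOutputAggregator.aggregate, and the scheduler then either resumes the request with that many computed tokens, recomputing the last token when the full prompt was loaded, or treats the load as failed when nothing was loaded.

def aggregate_loaded_tokens(
        per_worker: list[dict[str, int]]) -> dict[str, int]:
    """Min-reduce per-worker loaded-token counts (hypothetical helper).

    A request's KV cache is only usable up to the smallest number of tokens
    that every worker managed to load, so keep the minimum per request ID.
    """
    merged: dict[str, int] = {}
    for worker_dict in per_worker:
        for req_id, num_loaded in (worker_dict or {}).items():
            merged[req_id] = min(merged.get(req_id, num_loaded), num_loaded)
    return merged


def resume_or_fail(num_loaded: int, num_prompt_tokens: int) -> int:
    """Return the token count to treat as computed, or -1 on a failed load."""
    if num_loaded <= 0:
        return -1  # nothing usable was loaded; reschedule without remote KV
    if num_loaded == num_prompt_tokens:
        # Recompute the last token so the model still produces logits.
        return num_loaded - 1
    return num_loaded


# Example: two TP workers report different counts for the same request.
assert aggregate_loaded_tokens([{"req-0": 96}, {"req-0": 128}]) == {"req-0": 96}
assert resume_or_fail(96, 128) == 96
assert resume_or_fail(128, 128) == 127
assert resume_or_fail(0, 128) == -1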