[KV Connector] More async support for get_num_new_matched_tokens (#23620)

Signed-off-by: ApostaC <yihua98@uchicago.edu>
This commit is contained in:
Yihua Cheng 2025-09-09 21:23:37 -07:00 committed by GitHub
parent 83dd28aae4
commit b4a01aaf95
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 23 additions and 8 deletions

View File

@ -243,7 +243,7 @@ class KVConnectorBase_V1(ABC):
self, self,
request: "Request", request: "Request",
num_computed_tokens: int, num_computed_tokens: int,
) -> tuple[int, bool]: ) -> tuple[Optional[int], bool]:
""" """
Get number of new tokens that can be loaded from the Get number of new tokens that can be loaded from the
external KV cache beyond the num_computed_tokens. external KV cache beyond the num_computed_tokens.
@ -255,8 +255,11 @@ class KVConnectorBase_V1(ABC):
Returns: Returns:
A tuple with the following elements: A tuple with the following elements:
- The number of tokens that can be loaded from the - An optional number of tokens that can be loaded from the
external KV cache beyond what is already computed. external KV cache beyond what is already computed.
If None, it means that the connector needs more time to
determine the number of matched tokens, and the scheduler
should query for this request again later.
- `True` if external KV cache tokens will be loaded - `True` if external KV cache tokens will be loaded
asynchronously (between scheduler steps). Must be asynchronously (between scheduler steps). Must be
'False' if the first element is 0. 'False' if the first element is 0.

View File

@ -110,7 +110,7 @@ class LMCacheConnectorV1(KVConnectorBase_V1):
self, self,
request: "Request", request: "Request",
num_computed_tokens: int, num_computed_tokens: int,
) -> tuple[int, bool]: ) -> tuple[Optional[int], bool]:
""" """
Get number of new tokens that can be loaded from the Get number of new tokens that can be loaded from the
external KV cache beyond the num_computed_tokens. external KV cache beyond the num_computed_tokens.

View File

@ -143,11 +143,15 @@ class MultiConnector(KVConnectorBase_V1):
self, self,
request: "Request", request: "Request",
num_computed_tokens: int, num_computed_tokens: int,
) -> tuple[int, bool]: ) -> tuple[Optional[int], bool]:
to_return = (0, False) to_return = (0, False)
for i, c in enumerate(self._connectors): for i, c in enumerate(self._connectors):
toks, load_async = c.get_num_new_matched_tokens( toks, load_async = c.get_num_new_matched_tokens(
request, num_computed_tokens) request, num_computed_tokens)
# If there is a connector still looking up the matches,
# we return None to indicate that we are not done yet.
if toks is None:
return (None, False)
# The first connector that has new matched tokens will be assigned # The first connector that has new matched tokens will be assigned
# to this request. # to this request.
if to_return[0] == 0 and toks > 0: if to_return[0] == 0 and toks > 0:

View File

@ -162,7 +162,7 @@ class NixlConnector(KVConnectorBase_V1):
def get_num_new_matched_tokens( def get_num_new_matched_tokens(
self, request: "Request", self, request: "Request",
num_computed_tokens: int) -> tuple[int, bool]: num_computed_tokens: int) -> tuple[Optional[int], bool]:
assert self.connector_scheduler is not None assert self.connector_scheduler is not None
return self.connector_scheduler.get_num_new_matched_tokens( return self.connector_scheduler.get_num_new_matched_tokens(
request, num_computed_tokens) request, num_computed_tokens)

View File

@ -3,7 +3,7 @@
import hashlib import hashlib
import os import os
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING from typing import TYPE_CHECKING, Optional
import safetensors import safetensors
import torch import torch
@ -238,7 +238,7 @@ class SharedStorageConnector(KVConnectorBase_V1):
self, self,
request: "Request", request: "Request",
num_computed_tokens: int, num_computed_tokens: int,
) -> tuple[int, bool]: ) -> tuple[Optional[int], bool]:
""" """
Get number of new tokens that can be loaded from the Get number of new tokens that can be loaded from the
external KV cache beyond the num_computed_tokens. external KV cache beyond the num_computed_tokens.

View File

@ -387,6 +387,14 @@ class Scheduler(SchedulerInterface):
self.connector.get_num_new_matched_tokens( self.connector.get_num_new_matched_tokens(
request, num_new_local_computed_tokens)) request, num_new_local_computed_tokens))
if num_external_computed_tokens is None:
# The request cannot be scheduled because
# the KVConnector couldn't determine
# the number of matched tokens.
self.waiting.pop_request()
skipped_waiting_requests.prepend_request(request)
continue
# Total computed tokens (local + external). # Total computed tokens (local + external).
num_computed_tokens = (num_new_local_computed_tokens + num_computed_tokens = (num_new_local_computed_tokens +
num_external_computed_tokens) num_external_computed_tokens)