v1: Pass KVConnectorOutput to scheduler-side (#22157)

Signed-off-by: Or Ozeri <oro@il.ibm.com>
This commit is contained in:
Or Ozeri 2025-08-09 09:09:51 +03:00 committed by GitHub
parent 6ade99eafa
commit 7ad7adb67f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 22 additions and 0 deletions

View File

@ -12,6 +12,8 @@ The class provides the following primitives:
times for a given request and should be side-effect free. times for a given request and should be side-effect free.
update_state_after_alloc() - update KVConnector state after update_state_after_alloc() - update KVConnector state after
temporary buffer alloc by the CacheManager. temporary buffer alloc by the CacheManager.
update_connector_output() - update KVConnector state after
output is received from worker-side connectors.
request_finished() - called when a request is finished, with request_finished() - called when a request is finished, with
the computed kv cache blocks for the request. the computed kv cache blocks for the request.
Returns whether KV cache should be freed now or will be Returns whether KV cache should be freed now or will be
@ -38,6 +40,7 @@ import torch
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.outputs import KVConnectorOutput
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.backends.abstract import AttentionMetadata
@ -283,6 +286,16 @@ class KVConnectorBase_V1(ABC):
""" """
pass pass
def update_connector_output(self, connector_output: KVConnectorOutput):
"""
Update KVConnector state from worker-side connectors output.
Args:
connector_output (KVConnectorOutput): the worker-side
connectors output.
"""
return
def request_finished( def request_finished(
self, self,
request: "Request", request: "Request",

View File

@ -14,6 +14,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import (
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.outputs import KVConnectorOutput
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.backends.abstract import AttentionMetadata
@ -177,6 +178,10 @@ class MultiConnector(KVConnectorBase_V1):
self._extra_async_saves = {} self._extra_async_saves = {}
return metadata return metadata
def update_connector_output(self, connector_output: KVConnectorOutput):
for c in self._connectors:
c.update_connector_output(connector_output)
def request_finished( def request_finished(
self, self,
request: "Request", request: "Request",

View File

@ -1150,6 +1150,10 @@ class Scheduler(SchedulerInterface):
# if finished_recving: add to state so we can # if finished_recving: add to state so we can
scheduler the request during the next step. scheduler the request during the next step.
""" """
assert self.connector is not None
self.connector.update_connector_output(kv_connector_output)
# KV Connector:: update recv and send status from last step. # KV Connector:: update recv and send status from last step.
for req_id in (kv_connector_output.finished_recving or ()): for req_id in (kv_connector_output.finished_recving or ()):
logger.debug("Finished recving KV transfer for request %s", req_id) logger.debug("Finished recving KV transfer for request %s", req_id)