[KVConnector] Always call connector clear_metadata() at end of step (#20756)

Signed-off-by: Nick Hill <nhill@redhat.com>
Co-authored-by: David Ben-David <sdavidbd@gmail.com>
This commit is contained in:
Nick Hill 2025-07-10 22:37:27 +01:00 committed by GitHub
parent fdadb6f43a
commit 574ad60db9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 25 additions and 26 deletions

View File

@ -57,7 +57,7 @@ class KVConnectorRole(enum.Enum):
WORKER = 1 WORKER = 1
class KVConnectorMetadata: class KVConnectorMetadata(ABC): # noqa: B024
""" """
Abstract Metadata used to communicate between the Abstract Metadata used to communicate between the
Scheduler KVConnector and Worker KVConnector. Scheduler KVConnector and Worker KVConnector.
@ -71,7 +71,7 @@ class KVConnectorBase_V1(ABC):
logger.warning( logger.warning(
"Initializing KVConnectorBase_V1. This API is experimental and " "Initializing KVConnectorBase_V1. This API is experimental and "
"subject to change in the future as we iterate the design.") "subject to change in the future as we iterate the design.")
self._connector_metadata = KVConnectorMetadata() self._connector_metadata: Optional[KVConnectorMetadata] = None
self._vllm_config = vllm_config self._vllm_config = vllm_config
self._role = role self._role = role
@ -102,7 +102,7 @@ class KVConnectorBase_V1(ABC):
This function should be called by the model runner every time This function should be called by the model runner every time
after the model execution. after the model execution.
""" """
self._connector_metadata = KVConnectorMetadata() self._connector_metadata = None
def _get_connector_metadata(self) -> KVConnectorMetadata: def _get_connector_metadata(self) -> KVConnectorMetadata:
"""Get the connector metadata. """Get the connector metadata.
@ -112,6 +112,9 @@ class KVConnectorBase_V1(ABC):
Returns: Returns:
ConnectorMetadata: the connector metadata. ConnectorMetadata: the connector metadata.
""" """
# Should only be called while set to valid metadata.
assert self._connector_metadata is not None
return self._connector_metadata return self._connector_metadata
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):

View File

@ -250,28 +250,24 @@ class MultiprocExecutor(Executor):
self, outputs: list[ModelRunnerOutput]) -> ModelRunnerOutput: self, outputs: list[ModelRunnerOutput]) -> ModelRunnerOutput:
# aggregate finished_sending, finished_recving from all workers # aggregate finished_sending, finished_recving from all workers
def update_finished_set(req_ids: Optional[set[str]],
                        remaining_count_dict: dict[str, int],
                        finished_set: set[str]) -> None:
    """Count down outstanding reports for each request id in ``req_ids``.

    Every id in ``req_ids`` has its entry in ``remaining_count_dict``
    decremented by one. When the decremented value reaches zero the id is
    added to ``finished_set`` and its entry is removed from
    ``remaining_count_dict``; otherwise the decremented value is stored
    back. Both ``remaining_count_dict`` and ``finished_set`` are mutated
    in place.

    Args:
        req_ids: Request ids reported in one output; ``None`` or an empty
            set means there is nothing to process.
        remaining_count_dict: Per-request count of reports still
            outstanding.
        finished_set: Accumulator for request ids whose count hit zero.

    Raises:
        KeyError: If an id in ``req_ids`` has no entry in
            ``remaining_count_dict``.
    """
    if not req_ids:
        return
    for request in req_ids:
        pending = remaining_count_dict[request] - 1
        if pending != 0:
            # Still waiting on further reports for this request.
            remaining_count_dict[request] = pending
            continue
        # Count exhausted: record the request as finished and drop its
        # bookkeeping entry.
        finished_set.add(request)
        del remaining_count_dict[request]
finished_sending = set[str]() finished_sending = set[str]()
finished_recving = set[str]() finished_recving = set[str]()
for output in outputs: for output in outputs:
# update finished_sending update_finished_set(output.finished_sending,
for req_id in output.finished_sending or []: self._send_remaining_count, finished_sending)
new_count = self._send_remaining_count[req_id] - 1 update_finished_set(output.finished_recving,
if new_count == 0: self._recv_remaining_count, finished_recving)
# got response from all workers, report back to scheduler
finished_sending.add(req_id)
del self._send_remaining_count[req_id]
else:
self._send_remaining_count[req_id] = new_count
# update finished_recving
for req_id in output.finished_recving or []:
new_count = self._recv_remaining_count[req_id] - 1
if new_count == 0:
# got response from all workers, report back to scheduler
finished_recving.add(req_id)
del self._recv_remaining_count[req_id]
else:
self._recv_remaining_count[req_id] = new_count
# select output of the worker specified by output_rank # select output of the worker specified by output_rank
output = outputs[self.output_rank] output = outputs[self.output_rank]

View File

@ -1539,10 +1539,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
attn_metadata, attn_metadata,
) )
# Clear KVConnector state after all KVs are generated.
if has_kv_transfer_group():
get_kv_transfer_group().clear_connector_metadata()
self.eplb_step() self.eplb_step()
return ModelRunnerOutput( return ModelRunnerOutput(

View File

@ -338,6 +338,10 @@ class Worker(WorkerBase):
output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT) output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
output.finished_sending = finished_sending output.finished_sending = finished_sending
output.finished_recving = finished_recving output.finished_recving = finished_recving
# Clear KVConnector state for this step.
get_kv_transfer_group().clear_connector_metadata()
# with a connector, the scheduler expects output from all workers # with a connector, the scheduler expects output from all workers
return output return output