[KVConnector] Always call connector clear_metadata() at end of step (#20756)

Signed-off-by: Nick Hill <nhill@redhat.com>
Co-authored-by: David Ben-David <sdavidbd@gmail.com>
This commit is contained in:
Nick Hill 2025-07-10 22:37:27 +01:00 committed by GitHub
parent fdadb6f43a
commit 574ad60db9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 25 additions and 26 deletions

View File

@ -57,7 +57,7 @@ class KVConnectorRole(enum.Enum):
WORKER = 1
class KVConnectorMetadata:
class KVConnectorMetadata(ABC): # noqa: B024
"""
Abstract Metadata used to communicate between the
Scheduler KVConnector and Worker KVConnector.
@ -71,7 +71,7 @@ class KVConnectorBase_V1(ABC):
logger.warning(
"Initializing KVConnectorBase_V1. This API is experimental and "
"subject to change in the future as we iterate the design.")
self._connector_metadata = KVConnectorMetadata()
self._connector_metadata: Optional[KVConnectorMetadata] = None
self._vllm_config = vllm_config
self._role = role
@ -102,7 +102,7 @@ class KVConnectorBase_V1(ABC):
This function should be called by the model runner every time
after the model execution.
"""
self._connector_metadata = KVConnectorMetadata()
self._connector_metadata = None
def _get_connector_metadata(self) -> KVConnectorMetadata:
"""Get the connector metadata.
@ -112,6 +112,9 @@ class KVConnectorBase_V1(ABC):
Returns:
ConnectorMetadata: the connector metadata.
"""
# Should only be called while set to valid metadata.
assert self._connector_metadata is not None
return self._connector_metadata
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):

View File

@ -250,28 +250,24 @@ class MultiprocExecutor(Executor):
self, outputs: list[ModelRunnerOutput]) -> ModelRunnerOutput:
# aggregate finished_sending, finished_recving from all workers
def update_finished_set(req_ids: Optional[set[str]],
                        remaining_count_dict: dict[str, int],
                        finished_set: set[str]) -> None:
    """Count down one report for each request id in ``req_ids``.

    Each id's entry in ``remaining_count_dict`` is decremented by one.
    An id whose count reaches zero is added to ``finished_set`` and its
    entry is removed from ``remaining_count_dict``; otherwise the
    decremented count is stored back.

    Args:
        req_ids: Request ids to process; ``None`` or an empty set is a
            no-op.
        remaining_count_dict: Per-request outstanding counts, mutated
            in place.
        finished_set: Receives the ids whose count reached zero,
            mutated in place.
    """
    # NOTE(review): the subscript lookup below assumes every id is
    # either already present or that the mapping supplies a default
    # (e.g. a defaultdict) -- confirm against the caller.
    if not req_ids:
        return
    for request_id in req_ids:
        remaining = remaining_count_dict[request_id] - 1
        if remaining != 0:
            # Still waiting on further reports for this request.
            remaining_count_dict[request_id] = remaining
            continue
        # Count exhausted: report the request and drop its bookkeeping.
        finished_set.add(request_id)
        del remaining_count_dict[request_id]
finished_sending = set[str]()
finished_recving = set[str]()
for output in outputs:
# update finished_sending
for req_id in output.finished_sending or []:
new_count = self._send_remaining_count[req_id] - 1
if new_count == 0:
# got response from all workers, report back to scheduler
finished_sending.add(req_id)
del self._send_remaining_count[req_id]
else:
self._send_remaining_count[req_id] = new_count
# update finished_recving
for req_id in output.finished_recving or []:
new_count = self._recv_remaining_count[req_id] - 1
if new_count == 0:
# got response from all workers, report back to scheduler
finished_recving.add(req_id)
del self._recv_remaining_count[req_id]
else:
self._recv_remaining_count[req_id] = new_count
update_finished_set(output.finished_sending,
self._send_remaining_count, finished_sending)
update_finished_set(output.finished_recving,
self._recv_remaining_count, finished_recving)
# select output of the worker specified by output_rank
output = outputs[self.output_rank]

View File

@ -1539,10 +1539,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
attn_metadata,
)
# Clear KVConnector state after all KVs are generated.
if has_kv_transfer_group():
get_kv_transfer_group().clear_connector_metadata()
self.eplb_step()
return ModelRunnerOutput(

View File

@ -338,6 +338,10 @@ class Worker(WorkerBase):
output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
output.finished_sending = finished_sending
output.finished_recving = finished_recving
# Clear KVConnector state for this step.
get_kv_transfer_group().clear_connector_metadata()
# with a connector, the scheduler expects output from all workers
return output