mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-13 19:34:59 +08:00
[KVConnector] Always call connector clear_metadata() at end of step (#20756)
Signed-off-by: Nick Hill <nhill@redhat.com> Co-authored-by: David Ben-David <sdavidbd@gmail.com>
This commit is contained in:
parent
fdadb6f43a
commit
574ad60db9
@@ -57,7 +57,7 @@ class KVConnectorRole(enum.Enum):
|
|||||||
WORKER = 1
|
WORKER = 1
|
||||||
|
|
||||||
|
|
||||||
class KVConnectorMetadata:
|
class KVConnectorMetadata(ABC): # noqa: B024
|
||||||
"""
|
"""
|
||||||
Abstract Metadata used to communicate between the
|
Abstract Metadata used to communicate between the
|
||||||
Scheduler KVConnector and Worker KVConnector.
|
Scheduler KVConnector and Worker KVConnector.
|
||||||
@@ -71,7 +71,7 @@ class KVConnectorBase_V1(ABC):
|
|||||||
logger.warning(
|
logger.warning(
|
||||||
"Initializing KVConnectorBase_V1. This API is experimental and "
|
"Initializing KVConnectorBase_V1. This API is experimental and "
|
||||||
"subject to change in the future as we iterate the design.")
|
"subject to change in the future as we iterate the design.")
|
||||||
self._connector_metadata = KVConnectorMetadata()
|
self._connector_metadata: Optional[KVConnectorMetadata] = None
|
||||||
self._vllm_config = vllm_config
|
self._vllm_config = vllm_config
|
||||||
self._role = role
|
self._role = role
|
||||||
|
|
||||||
@@ -102,7 +102,7 @@ class KVConnectorBase_V1(ABC):
|
|||||||
This function should be called by the model runner every time
|
This function should be called by the model runner every time
|
||||||
after the model execution.
|
after the model execution.
|
||||||
"""
|
"""
|
||||||
self._connector_metadata = KVConnectorMetadata()
|
self._connector_metadata = None
|
||||||
|
|
||||||
def _get_connector_metadata(self) -> KVConnectorMetadata:
|
def _get_connector_metadata(self) -> KVConnectorMetadata:
|
||||||
"""Get the connector metadata.
|
"""Get the connector metadata.
|
||||||
@@ -112,6 +112,9 @@ class KVConnectorBase_V1(ABC):
|
|||||||
Returns:
|
Returns:
|
||||||
ConnectorMetadata: the connector metadata.
|
ConnectorMetadata: the connector metadata.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Should only be called while set to valid metadata.
|
||||||
|
assert self._connector_metadata is not None
|
||||||
return self._connector_metadata
|
return self._connector_metadata
|
||||||
|
|
||||||
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
|
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
|
||||||
|
|||||||
@@ -250,28 +250,24 @@ class MultiprocExecutor(Executor):
|
|||||||
self, outputs: list[ModelRunnerOutput]) -> ModelRunnerOutput:
|
self, outputs: list[ModelRunnerOutput]) -> ModelRunnerOutput:
|
||||||
# aggregate finished_sending, finished_recving from all workers
|
# aggregate finished_sending, finished_recving from all workers
|
||||||
|
|
||||||
|
def update_finished_set(req_ids: Optional[set[str]],
|
||||||
|
remaining_count_dict: dict[str, int],
|
||||||
|
finished_set: set[str]) -> None:
|
||||||
|
for req_id in req_ids or ():
|
||||||
|
new_count = remaining_count_dict[req_id] - 1
|
||||||
|
if new_count == 0:
|
||||||
|
finished_set.add(req_id)
|
||||||
|
del remaining_count_dict[req_id]
|
||||||
|
else:
|
||||||
|
remaining_count_dict[req_id] = new_count
|
||||||
|
|
||||||
finished_sending = set[str]()
|
finished_sending = set[str]()
|
||||||
finished_recving = set[str]()
|
finished_recving = set[str]()
|
||||||
for output in outputs:
|
for output in outputs:
|
||||||
# update finished_sending
|
update_finished_set(output.finished_sending,
|
||||||
for req_id in output.finished_sending or []:
|
self._send_remaining_count, finished_sending)
|
||||||
new_count = self._send_remaining_count[req_id] - 1
|
update_finished_set(output.finished_recving,
|
||||||
if new_count == 0:
|
self._recv_remaining_count, finished_recving)
|
||||||
# got response from all workers, report back to scheduler
|
|
||||||
finished_sending.add(req_id)
|
|
||||||
del self._send_remaining_count[req_id]
|
|
||||||
else:
|
|
||||||
self._send_remaining_count[req_id] = new_count
|
|
||||||
|
|
||||||
# update finished_recving
|
|
||||||
for req_id in output.finished_recving or []:
|
|
||||||
new_count = self._recv_remaining_count[req_id] - 1
|
|
||||||
if new_count == 0:
|
|
||||||
# got response from all workers, report back to scheduler
|
|
||||||
finished_recving.add(req_id)
|
|
||||||
del self._recv_remaining_count[req_id]
|
|
||||||
else:
|
|
||||||
self._recv_remaining_count[req_id] = new_count
|
|
||||||
|
|
||||||
# select output of the worker specified by output_rank
|
# select output of the worker specified by output_rank
|
||||||
output = outputs[self.output_rank]
|
output = outputs[self.output_rank]
|
||||||
|
|||||||
@@ -1539,10 +1539,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
attn_metadata,
|
attn_metadata,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Clear KVConnector state after all KVs are generated.
|
|
||||||
if has_kv_transfer_group():
|
|
||||||
get_kv_transfer_group().clear_connector_metadata()
|
|
||||||
|
|
||||||
self.eplb_step()
|
self.eplb_step()
|
||||||
|
|
||||||
return ModelRunnerOutput(
|
return ModelRunnerOutput(
|
||||||
|
|||||||
@@ -338,6 +338,10 @@ class Worker(WorkerBase):
|
|||||||
output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
|
output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
|
||||||
output.finished_sending = finished_sending
|
output.finished_sending = finished_sending
|
||||||
output.finished_recving = finished_recving
|
output.finished_recving = finished_recving
|
||||||
|
|
||||||
|
# Clear KVConnector state for this step.
|
||||||
|
get_kv_transfer_group().clear_connector_metadata()
|
||||||
|
|
||||||
# with a connector, the scheduler expects output from all workers
|
# with a connector, the scheduler expects output from all workers
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user