mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-24 07:10:04 +08:00
[BugFix] Fix multi async save in MultiConnector (#18246)
Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
parent
d3d91b6f71
commit
1db4f47f81
@ -1,5 +1,6 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
import copy
|
import copy
|
||||||
|
from dataclasses import dataclass
|
||||||
from typing import TYPE_CHECKING, Any, Optional
|
from typing import TYPE_CHECKING, Any, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -21,9 +22,10 @@ if TYPE_CHECKING:
|
|||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class MultiKVConnectorMetadata(tuple[KVConnectorMetadata, ...],
|
@dataclass
|
||||||
KVConnectorMetadata):
|
class MultiKVConnectorMetadata(KVConnectorMetadata):
|
||||||
pass
|
metadata: tuple[KVConnectorMetadata, ...]
|
||||||
|
extra_async_saves: Optional[dict[str, int]] = None
|
||||||
|
|
||||||
|
|
||||||
class MultiConnector(KVConnectorBase_V1):
|
class MultiConnector(KVConnectorBase_V1):
|
||||||
@ -54,6 +56,7 @@ class MultiConnector(KVConnectorBase_V1):
|
|||||||
# Keeps track of *additional* remaining async saves (beyond 1) to be
|
# Keeps track of *additional* remaining async saves (beyond 1) to be
|
||||||
# finished per request. Not needed for async loads since we only allow
|
# finished per request. Not needed for async loads since we only allow
|
||||||
# a single connector to load.
|
# a single connector to load.
|
||||||
|
# Propagated from scheduler to worker side via the connector metadata.
|
||||||
self._extra_async_saves: dict[str, int] = {}
|
self._extra_async_saves: dict[str, int] = {}
|
||||||
|
|
||||||
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
|
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
|
||||||
@ -66,7 +69,10 @@ class MultiConnector(KVConnectorBase_V1):
|
|||||||
def bind_connector_metadata(
|
def bind_connector_metadata(
|
||||||
self, connector_metadata: KVConnectorMetadata) -> None:
|
self, connector_metadata: KVConnectorMetadata) -> None:
|
||||||
assert isinstance(connector_metadata, MultiKVConnectorMetadata)
|
assert isinstance(connector_metadata, MultiKVConnectorMetadata)
|
||||||
for c, cm in zip(self._connectors, connector_metadata):
|
if connector_metadata.extra_async_saves:
|
||||||
|
self._extra_async_saves.update(
|
||||||
|
connector_metadata.extra_async_saves)
|
||||||
|
for c, cm in zip(self._connectors, connector_metadata.metadata):
|
||||||
c.bind_connector_metadata(cm)
|
c.bind_connector_metadata(cm)
|
||||||
|
|
||||||
def clear_connector_metadata(self) -> None:
|
def clear_connector_metadata(self) -> None:
|
||||||
@ -152,8 +158,13 @@ class MultiConnector(KVConnectorBase_V1):
|
|||||||
def build_connector_meta(
|
def build_connector_meta(
|
||||||
self,
|
self,
|
||||||
scheduler_output: SchedulerOutput) -> MultiKVConnectorMetadata:
|
scheduler_output: SchedulerOutput) -> MultiKVConnectorMetadata:
|
||||||
return MultiKVConnectorMetadata(
|
metadata = MultiKVConnectorMetadata(metadata=tuple(
|
||||||
c.build_connector_meta(scheduler_output) for c in self._connectors)
|
c.build_connector_meta(scheduler_output)
|
||||||
|
for c in self._connectors))
|
||||||
|
if self._extra_async_saves:
|
||||||
|
metadata.extra_async_saves = self._extra_async_saves
|
||||||
|
self._extra_async_saves = {}
|
||||||
|
return metadata
|
||||||
|
|
||||||
def request_finished(
|
def request_finished(
|
||||||
self,
|
self,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user