mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-14 17:05:36 +08:00
[Bugfix] Set KVTransferConfig.engine_id in post_init (#18576)
Signed-off-by: Linkun Chen <github@lkchen.net>
This commit is contained in:
parent
93ecb8139c
commit
e44d8ce8c7
@ -239,3 +239,11 @@ def get_connector_events() -> dict[str, list[str]]:
|
|||||||
print(f"[ERROR] Could not read connector events for {name}: {e}")
|
print(f"[ERROR] Could not read connector events for {name}: {e}")
|
||||||
|
|
||||||
return connector_events
|
return connector_events
|
||||||
|
|
||||||
|
|
||||||
|
def test_engine_id_conflict():
|
||||||
|
configs = [KVTransferConfig() for _ in range(2)]
|
||||||
|
ids = [config.engine_id for config in configs]
|
||||||
|
assert ids[0] != ids[1], (
|
||||||
|
"Engine IDs should be different for different configs. "
|
||||||
|
f"Got {ids}")
|
||||||
|
|||||||
@ -3495,7 +3495,7 @@ class KVTransferConfig:
|
|||||||
"""The KV connector for vLLM to transmit KV caches between vLLM instances.
|
"""The KV connector for vLLM to transmit KV caches between vLLM instances.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
engine_id: str = str(uuid.uuid4())
|
engine_id: Optional[str] = None
|
||||||
"""The engine id for KV transfers."""
|
"""The engine id for KV transfers."""
|
||||||
|
|
||||||
kv_buffer_device: Optional[str] = "cuda"
|
kv_buffer_device: Optional[str] = "cuda"
|
||||||
@ -3552,6 +3552,9 @@ class KVTransferConfig:
|
|||||||
return hash_str
|
return hash_str
|
||||||
|
|
||||||
def __post_init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
|
if self.engine_id is None:
|
||||||
|
self.engine_id = str(uuid.uuid4())
|
||||||
|
|
||||||
if self.kv_role is not None and self.kv_role not in get_args(KVRole):
|
if self.kv_role is not None and self.kv_role not in get_args(KVRole):
|
||||||
raise ValueError(f"Unsupported kv_role: {self.kv_role}. "
|
raise ValueError(f"Unsupported kv_role: {self.kv_role}. "
|
||||||
f"Supported roles are {get_args(KVRole)}")
|
f"Supported roles are {get_args(KVRole)}")
|
||||||
|
|||||||
@ -537,6 +537,7 @@ class NixlConnectorWorker:
|
|||||||
|
|
||||||
def add_remote_agent(self, nixl_agent_meta: NixlAgentMetadata):
|
def add_remote_agent(self, nixl_agent_meta: NixlAgentMetadata):
|
||||||
engine_id = nixl_agent_meta.engine_id
|
engine_id = nixl_agent_meta.engine_id
|
||||||
|
assert engine_id != self.engine_id, "Conflict engine id found!"
|
||||||
if engine_id in self._remote_agents:
|
if engine_id in self._remote_agents:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user