[Bugfix] Set KVTransferConfig.engine_id in post_init (#18576)

Signed-off-by: Linkun Chen <github@lkchen.net>
This commit is contained in:
lkchen 2025-05-22 19:54:42 -07:00 committed by GitHub
parent 93ecb8139c
commit e44d8ce8c7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 13 additions and 1 deletions

View File

@ -239,3 +239,11 @@ def get_connector_events() -> dict[str, list[str]]:
print(f"[ERROR] Could not read connector events for {name}: {e}")
return connector_events
def test_engine_id_conflict():
configs = [KVTransferConfig() for _ in range(2)]
ids = [config.engine_id for config in configs]
assert ids[0] != ids[1], (
"Engine IDs should be different for different configs. "
f"Got {ids}")

View File

@ -3495,7 +3495,7 @@ class KVTransferConfig:
"""The KV connector for vLLM to transmit KV caches between vLLM instances.
"""
engine_id: str = str(uuid.uuid4())
engine_id: Optional[str] = None
"""The engine id for KV transfers."""
kv_buffer_device: Optional[str] = "cuda"
@ -3552,6 +3552,9 @@ class KVTransferConfig:
return hash_str
def __post_init__(self) -> None:
if self.engine_id is None:
self.engine_id = str(uuid.uuid4())
if self.kv_role is not None and self.kv_role not in get_args(KVRole):
raise ValueError(f"Unsupported kv_role: {self.kv_role}. "
f"Supported roles are {get_args(KVRole)}")

View File

@ -537,6 +537,7 @@ class NixlConnectorWorker:
def add_remote_agent(self, nixl_agent_meta: NixlAgentMetadata):
engine_id = nixl_agent_meta.engine_id
assert engine_id != self.engine_id, "Conflict engine id found!"
if engine_id in self._remote_agents:
return