[BugFix]: Properly set engine_id when using multi connector (#19487)
Signed-off-by: Nick Hill <nhill@redhat.com>
Co-authored-by: leiyiming <leiyiming@kingsoft.com>
Co-authored-by: Nick Hill <nhill@redhat.com>

parent 332d4cb17b
commit cd587c93ef
```diff
@@ -76,6 +76,9 @@ class TestSharedStorageConnector(SharedStorageConnector):
         return attr
 
 
+# This relies on "fork" multiprocessing method being used.
+# It's the default but vLLM may fall back to spawn if for example CUDA
+# is already initialized.
 KVConnectorFactory.register_connector("TestSharedStorageConnector",
                                       TestSharedStorageConnector.__module__,
                                       TestSharedStorageConnector.__name__)
```
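The comment added above is load-bearing: `KVConnectorFactory.register_connector` mutates module-level state at import time, which forked workers inherit but spawned workers do not (spawn re-imports modules in a fresh interpreter). A minimal sketch of that difference, using a hypothetical registry rather than vLLM's factory (Linux only, since it exercises "fork"):

```python
import multiprocessing as mp

# Hypothetical module-level registry standing in for KVConnectorFactory's.
_REGISTRY: dict[str, str] = {}


def worker(q: mp.Queue) -> None:
    # Reports whether this child process can see the parent's registration.
    q.put("TestSharedStorageConnector" in _REGISTRY)


if __name__ == "__main__":
    _REGISTRY["TestSharedStorageConnector"] = "some.module"
    for method in ("fork", "spawn"):
        ctx = mp.get_context(method)
        q = ctx.Queue()
        p = ctx.Process(target=worker, args=(q,))
        p.start()
        # "fork" prints True (state inherited); "spawn" prints False
        # (fresh interpreter, module re-imported without the registration).
        print(method, q.get())
        p.join()
```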
```diff
@@ -166,8 +166,8 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
         super().__init__(*args, **kwargs)
         self._hand_shake_latency = hand_shake_latency
 
-    def _nixl_handshake(self, host: str, port: int,
-                        remote_tp_size: int) -> dict[int, str]:
+    def _nixl_handshake(self, host: str, port: int, remote_tp_size: int,
+                        expected_engine_id: str) -> dict[int, str]:
         # Mimic slow _nixl_handshake, as well as bypass zmq communication.
         time.sleep(self._hand_shake_latency)
         # These should've been done in register_kv_caches(), called by
@@ -177,6 +177,8 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
         self.num_blocks = 1
         self.dst_num_blocks[self.engine_id] = self.num_blocks
 
+        assert expected_engine_id == self.REMOTE_ENGINE_ID
+
         remote_agent_name = self.add_remote_agent(
             NixlAgentMetadata(
                 engine_id=self.REMOTE_ENGINE_ID,
```
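For reference, the fake keeps the real `_nixl_handshake` signature so the new `expected_engine_id` argument is exercised end to end without any network traffic. A stripped-down sketch of the same override-and-assert pattern (names here are illustrative, not vLLM's):

```python
class Worker:
    """Stand-in for NixlConnectorWorker; the real handshake hits the network."""

    def _nixl_handshake(self, host: str, port: int, remote_tp_size: int,
                        expected_engine_id: str) -> dict[int, str]:
        raise NotImplementedError


class FakeWorker(Worker):
    REMOTE_ENGINE_ID = "remote-engine-0"

    def _nixl_handshake(self, host: str, port: int, remote_tp_size: int,
                        expected_engine_id: str) -> dict[int, str]:
        # Bypass the network but keep the contract under test: the caller
        # must thread the expected engine id all the way through.
        assert expected_engine_id == self.REMOTE_ENGINE_ID
        return {0: "agent-0"}


assert FakeWorker()._nixl_handshake("localhost", 5555, 1,
                                    "remote-engine-0") == {0: "agent-0"}
```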
```diff
@@ -47,7 +47,10 @@ class MultiConnector(KVConnectorBase_V1):
         assert ktcs is not None
         for ktc in ktcs:
             temp_config = copy.copy(vllm_config)
-            temp_config.kv_transfer_config = KVTransferConfig(**ktc)
+            engine_id = ktc.get("engine_id",
+                                vllm_config.kv_transfer_config.engine_id)
+            temp_config.kv_transfer_config = KVTransferConfig(
+                **ktc, engine_id=engine_id)
             self._connectors.append(
                 KVConnectorFactory.create_connector_v1(temp_config, role))
 
```
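This is the core of the fix. Previously each sub-connector's `KVTransferConfig` was built straight from its dict, so a sub-connector that omitted `engine_id` got whatever default the config class generates rather than the parent's id. A hedged sketch of the resolution logic, with a simplified stand-in config class (the `default_factory` here is an assumption mirroring the symptom the fix addresses):

```python
import uuid
from dataclasses import dataclass, field


@dataclass
class KVTransferConfig:  # simplified stand-in for vLLM's config class
    kv_connector: str = ""
    # Assumption for this sketch: an omitted engine_id gets a fresh id,
    # which is exactly what made sub-connectors disagree with the parent.
    engine_id: str = field(default_factory=lambda: str(uuid.uuid4()))


parent = KVTransferConfig(kv_connector="MultiConnector")
ktcs = [
    {"kv_connector": "NixlConnector"},                # inherits parent's id
    {"kv_connector": "SharedStorageConnector",
     "engine_id": "my-explicit-id"},                  # keeps its own id
]

for ktc in ktcs:
    engine_id = ktc.get("engine_id", parent.engine_id)
    # Merge into one dict so engine_id is never passed twice; passing
    # **ktc together with engine_id=... would raise TypeError whenever
    # ktc itself contained the key.
    sub = KVTransferConfig(**{**ktc, "engine_id": engine_id})
    assert sub.engine_id == (ktc.get("engine_id") or parent.engine_id)
```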
```diff
@@ -187,7 +190,7 @@ class MultiConnector(KVConnectorBase_V1):
                 async_saves += 1
             if txfer_params is not None:
                 if kv_txfer_params is not None:
-                    #TODO we can probably change this to merge the dicts here,
+                    # TODO we can probably change this to merge the dicts here,
                     # checking for key clashes.
                     raise RuntimeError(
                         "Only one connector can produce KV transfer params")
```
```diff
@@ -488,8 +488,13 @@ class NixlConnectorWorker:
                 "Connection listener got unexpected message %s", msg)
             sock.send_multipart((identity, b"", encoded_data))
 
-    def _nixl_handshake(self, host: str, port: int,
-                        remote_tp_size: int) -> dict[int, str]:
+    def _nixl_handshake(
+        self,
+        host: str,
+        port: int,
+        remote_tp_size: int,
+        expected_engine_id: str,
+    ) -> dict[int, str]:
         """Do a NIXL handshake with a remote instance."""
 
         start_time = time.perf_counter()
```
```diff
@@ -498,26 +503,6 @@ class NixlConnectorWorker:
         # a hack to keep us moving. We will switch when moving to etcd
         # or where we have a single ZMQ socket in the scheduler.
 
-        def handshake(path: str, rank: int) -> str:
-            # Send query for the request.
-            with zmq_ctx(zmq.REQ, path) as sock:
-                sock.send(GET_META_MSG)
-                metadata_bytes = sock.recv()
-                decoder = msgspec.msgpack.Decoder(NixlAgentMetadata)
-                metadata = decoder.decode(metadata_bytes)
-                got_metadata_time = time.perf_counter()
-
-                # Register Remote agent.
-                remote_agent_name = self.add_remote_agent(
-                    metadata, rank, remote_tp_size)
-                setup_agent_time = time.perf_counter()
-
-                logger.debug("NIXL handshake: get metadata took: %s",
-                             got_metadata_time - start_time)
-                logger.debug("NIXL handshake: add agent took: %s",
-                             setup_agent_time - got_metadata_time)
-                return remote_agent_name
-
         # Handshake only with the remote TP rank that current local rank will
         # pull from. With homogeneous TP it happens to be the same rank_i.
         tp_ratio = self._tp_size[self.engine_id] // remote_tp_size
```
```diff
@@ -525,8 +510,32 @@ class NixlConnectorWorker:
         path = make_zmq_path("tcp", host, port + p_remote_rank)
         logger.debug("Querying metadata on path: %s at remote rank %s", path,
                      p_remote_rank)
 
+        # Send query for the request.
+        with zmq_ctx(zmq.REQ, path) as sock:
+            sock.send(GET_META_MSG)
+            metadata_bytes = sock.recv()
+            decoder = msgspec.msgpack.Decoder(NixlAgentMetadata)
+            metadata = decoder.decode(metadata_bytes)
+            got_metadata_time = time.perf_counter()
+            logger.debug("NIXL handshake: get metadata took: %s",
+                         got_metadata_time - start_time)
+
+            # Ensure engine id matches.
+            if metadata.engine_id != expected_engine_id:
+                raise RuntimeError(f"Remote NIXL agent engine ID mismatch. "
+                                   f"Expected {expected_engine_id},"
+                                   f"received {metadata.engine_id}.")
+
+            # Register Remote agent.
+            remote_agent_name = self.add_remote_agent(metadata, p_remote_rank,
+                                                      remote_tp_size)
+            setup_agent_time = time.perf_counter()
+            logger.debug("NIXL handshake: add agent took: %s",
+                         setup_agent_time - got_metadata_time)
+
         # Remote rank -> agent name.
-        return {p_remote_rank: handshake(path, p_remote_rank)}
+        return {p_remote_rank: remote_agent_name}
 
     def _background_nixl_handshake(self, req_id: str,
                                    remote_engine_id: EngineId, meta: ReqMeta):
```
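With the nested helper inlined, the handshake now rejects a peer whose advertised engine id is not the one the request targeted. A self-contained sketch of the decode-and-verify step, using `msgspec` as above but with a hypothetical `AgentMeta` struct in place of `NixlAgentMetadata` and the zmq transport elided:

```python
import msgspec


class AgentMeta(msgspec.Struct):  # stand-in for NixlAgentMetadata
    engine_id: str
    agent_name: str


def verify_handshake(payload: bytes, expected_engine_id: str) -> str:
    # Decode the remote agent's metadata and ensure the engine id matches.
    meta = msgspec.msgpack.Decoder(AgentMeta).decode(payload)
    if meta.engine_id != expected_engine_id:
        raise RuntimeError(f"Remote agent engine ID mismatch. "
                           f"Expected {expected_engine_id}, "
                           f"received {meta.engine_id}.")
    return meta.agent_name


payload = msgspec.msgpack.encode(
    AgentMeta(engine_id="engine-A", agent_name="agent-0"))
assert verify_handshake(payload, "engine-A") == "agent-0"
try:
    verify_handshake(payload, "engine-B")
except RuntimeError as e:
    print(e)  # Remote agent engine ID mismatch. ...
```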
```diff
@@ -535,7 +544,7 @@ class NixlConnectorWorker:
         if fut is None:
             fut = self._handshake_initiation_executor.submit(
                 self._nixl_handshake, meta.remote_host, meta.remote_port,
-                meta.tp_size)
+                meta.tp_size, remote_engine_id)
             self._handshake_futures[remote_engine_id] = fut
 
             def done_callback(f: Future[dict[int, str]], eid=remote_engine_id):
```
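The only change here is threading `remote_engine_id` through to the executor so the handshake can verify it. A runnable sketch of the submit-plus-callback idiom (illustrative names, not vLLM's); binding `eid` as a default argument snapshots the id when the callback is defined, matching the style in the diff, though a plain closure would behave the same here:

```python
from concurrent.futures import Future, ThreadPoolExecutor

executor = ThreadPoolExecutor(max_workers=2)
handshake_futures: dict[str, Future] = {}


def nixl_handshake(host: str, port: int, tp_size: int,
                   expected_engine_id: str) -> dict[int, str]:
    # Placeholder for the real handshake; returns remote rank -> agent name.
    return {0: f"agent-for-{expected_engine_id}"}


def start_background_handshake(remote_engine_id: str) -> None:
    fut = executor.submit(nixl_handshake, "localhost", 5555, 1,
                          remote_engine_id)
    handshake_futures[remote_engine_id] = fut

    def done_callback(f: Future, eid=remote_engine_id):
        # Runs once the handshake completes; clean up and report.
        handshake_futures.pop(eid, None)
        print(eid, "->", f.result())

    fut.add_done_callback(done_callback)


for eid in ("engine-A", "engine-B"):
    start_background_handshake(eid)
executor.shutdown(wait=True)
```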
```diff
@@ -738,10 +747,10 @@ class NixlConnectorWorker:
         if remote_tp_rank in self._remote_agents.get(engine_id, {}):
             return self._remote_agents[engine_id][remote_tp_rank]
 
-        if engine_id in self._tp_size:
-            assert self._tp_size[engine_id] == remote_tp_size
-        else:
+        if engine_id not in self._tp_size:
             self._tp_size[engine_id] = remote_tp_size
+        else:
+            assert self._tp_size[engine_id] == remote_tp_size
         # We may eventually enable this after asserting equality in cache
         # layout and close outputs.
         assert nixl_agent_meta.attn_backend_name == self.backend_name
```
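The inverted condition is behavior-preserving; it simply puts the first-seen registration first and the consistency check in the `else`. A tiny standalone sketch of that bookkeeping pattern:

```python
_tp_size: dict[str, int] = {}


def record_remote_tp_size(engine_id: str, remote_tp_size: int) -> None:
    if engine_id not in _tp_size:
        _tp_size[engine_id] = remote_tp_size  # first handshake: record it
    else:
        # Later handshakes: the remote's TP size must not change.
        assert _tp_size[engine_id] == remote_tp_size


record_remote_tp_size("engine-A", 4)
record_remote_tp_size("engine-A", 4)    # consistent: fine
# record_remote_tp_size("engine-A", 8)  # would trip the assert
```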