mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-21 02:46:58 +08:00
init
Signed-off-by: NickLucche <nlucches@redhat.com>
This commit is contained in:
parent
5d45b77124
commit
684c9b7b6d
@ -525,11 +525,15 @@ class NixlConnectorWorker:
|
|||||||
if remote_tp_size is None:
|
if remote_tp_size is None:
|
||||||
assert remote_engine_id is not None
|
assert remote_engine_id is not None
|
||||||
remote_tp_size = self.remote_tp_size[remote_engine_id]
|
remote_tp_size = self.remote_tp_size[remote_engine_id]
|
||||||
assert self.tp_size % remote_tp_size == 0, (
|
if self.tp_size >= remote_tp_size:
|
||||||
f"Local tensor parallel size {self.tp_size} is not divisible "
|
assert self.tp_size % remote_tp_size == 0, (
|
||||||
f"by remote tensor parallel size {remote_tp_size}."
|
f"Local tensor parallel size {self.tp_size} is not divisible "
|
||||||
)
|
f"by remote tensor parallel size {remote_tp_size}."
|
||||||
return self.tp_size // remote_tp_size
|
)
|
||||||
|
return self.tp_size // remote_tp_size
|
||||||
|
else:
|
||||||
|
# P TP > D TP case, return the ratio as negative
|
||||||
|
return -remote_tp_size // self.tp_size
|
||||||
|
|
||||||
def is_kv_replicated(self, engine_id: EngineId) -> bool:
|
def is_kv_replicated(self, engine_id: EngineId) -> bool:
|
||||||
"""
|
"""
|
||||||
@ -538,22 +542,29 @@ class NixlConnectorWorker:
|
|||||||
"""
|
"""
|
||||||
tp_size = self.remote_tp_size[engine_id]
|
tp_size = self.remote_tp_size[engine_id]
|
||||||
return tp_size // self.total_num_kv_heads >= 1
|
return tp_size // self.total_num_kv_heads >= 1
|
||||||
|
|
||||||
def replicates_kv_cache(self, remote_engine_id: EngineId) -> bool:
|
def replicates_kv_cache(self, remote_engine_id: EngineId) -> bool:
|
||||||
# MLA is always replicated as the hidden dim can't be split.
|
# MLA is always replicated as the hidden dim can't be split.
|
||||||
return self.is_mla or self.is_kv_replicated(remote_engine_id)
|
# TODO docs
|
||||||
|
remote_tp_size = self.remote_tp_size[remote_engine_id]
|
||||||
|
return self.is_mla or self.is_kv_replicated(remote_engine_id) or self.tp_size < remote_tp_size
|
||||||
|
|
||||||
def get_target_remote_rank(
|
def get_target_remote_ranks(
|
||||||
self,
|
self,
|
||||||
remote_engine_id: Optional[EngineId] = None,
|
remote_engine_id: Optional[EngineId] = None,
|
||||||
remote_tp_size: Optional[int] = None,
|
remote_tp_size: Optional[int] = None,
|
||||||
) -> int:
|
) -> list[int]:
|
||||||
"""
|
"""
|
||||||
Get the remote TP rank (on P) that the current local TP rank
|
Get the remote TP rank (on P) that the current local TP rank
|
||||||
(on D) will read from.
|
(on D) will read from.
|
||||||
"""
|
"""
|
||||||
tp_ratio = self.tp_ratio(remote_engine_id, remote_tp_size)
|
tp_ratio = self.tp_ratio(remote_engine_id, remote_tp_size)
|
||||||
return self.tp_rank // tp_ratio
|
if tp_ratio > 0:
|
||||||
|
return [self.tp_rank // tp_ratio]
|
||||||
|
else:
|
||||||
|
# P TP > D TP case, D reads from |tp_ratio| remote workers.
|
||||||
|
tp_ratio = -tp_ratio
|
||||||
|
return [self.tp_rank*tp_ratio + i for i in range(tp_ratio)]
|
||||||
|
|
||||||
def __init__(self, vllm_config: VllmConfig, engine_id: str):
|
def __init__(self, vllm_config: VllmConfig, engine_id: str):
|
||||||
if NixlWrapper is None:
|
if NixlWrapper is None:
|
||||||
@ -638,8 +649,8 @@ class NixlConnectorWorker:
|
|||||||
self.copy_blocks: Optional[CopyBlocksOp] = None
|
self.copy_blocks: Optional[CopyBlocksOp] = None
|
||||||
|
|
||||||
# Map of engine_id -> kv_caches_base_addr. For TP case, each local
|
# Map of engine_id -> kv_caches_base_addr. For TP case, each local
|
||||||
# rank will still only pull from a single remote TP worker.
|
# rank may pull from multiple remote TP workers.
|
||||||
self.kv_caches_base_addr: dict[EngineId, list[int]] = {}
|
self.kv_caches_base_addr: defaultdict[EngineId, dict[int, list[int]]] = defaultdict(dict)
|
||||||
|
|
||||||
# Number of NIXL regions. Currently one region per cache
|
# Number of NIXL regions. Currently one region per cache
|
||||||
# (so 1 per layer for MLA, otherwise 2 per layer)
|
# (so 1 per layer for MLA, otherwise 2 per layer)
|
||||||
@ -649,7 +660,8 @@ class NixlConnectorWorker:
|
|||||||
# nixl_prepped_dlist_handle.
|
# nixl_prepped_dlist_handle.
|
||||||
self.src_xfer_side_handle: int = 0
|
self.src_xfer_side_handle: int = 0
|
||||||
# Map of engine_id -> nixl_prepped_dlist_handle (int)].
|
# Map of engine_id -> nixl_prepped_dlist_handle (int)].
|
||||||
self.dst_xfer_side_handles: dict[EngineId, int] = {}
|
# TODO do I need tp_Ratio of this?
|
||||||
|
self.dst_xfer_side_handles: defaultdict[EngineId, dict[int, int]] = defaultdict(dict)
|
||||||
|
|
||||||
# Map of engine_id -> num_blocks. All ranks in the same deployment will
|
# Map of engine_id -> num_blocks. All ranks in the same deployment will
|
||||||
# have the same number of blocks.
|
# have the same number of blocks.
|
||||||
@ -756,51 +768,52 @@ class NixlConnectorWorker:
|
|||||||
|
|
||||||
start_time = time.perf_counter()
|
start_time = time.perf_counter()
|
||||||
|
|
||||||
# NOTE(rob): we need each rank to have a unique port. This is
|
# When target instance TP > local TP, we need to perform multiple
|
||||||
# a hack to keep us moving. We will switch when moving to etcd
|
# handshakes. Do it in a single background job for simplicity.
|
||||||
# or where we have a single ZMQ socket in the scheduler.
|
# Regardless, only handshake with the remote TP rank(s) that current
|
||||||
|
# local rank will read from. Note that With homogeneous TP,
|
||||||
# Handshake only with the remote TP rank that current local rank will
|
# this happens to be the same single rank_i.
|
||||||
# pull from. With homogeneous TP it happens to be the same rank_i.
|
p_remote_ranks = self.kv_info.get_target_remote_ranks(
|
||||||
p_remote_rank = self.kv_info.get_target_remote_rank(
|
|
||||||
remote_tp_size=remote_tp_size
|
remote_tp_size=remote_tp_size
|
||||||
)
|
)
|
||||||
path = make_zmq_path("tcp", host, port + p_remote_rank)
|
remote_rank_to_agent_name = {}
|
||||||
logger.debug(
|
for remote_rank in p_remote_ranks:
|
||||||
"Querying metadata on path: %s at remote rank %s", path, p_remote_rank
|
path = make_zmq_path("tcp", host, port + remote_rank)
|
||||||
)
|
logger.warning(
|
||||||
|
"Querying metadata on path: %s at remote rank %s", path, remote_rank
|
||||||
# Send query for the request.
|
|
||||||
with zmq_ctx(zmq.REQ, path) as sock:
|
|
||||||
sock.send(GET_META_MSG)
|
|
||||||
metadata_bytes = sock.recv()
|
|
||||||
decoder = msgspec.msgpack.Decoder(NixlAgentMetadata)
|
|
||||||
metadata = decoder.decode(metadata_bytes)
|
|
||||||
got_metadata_time = time.perf_counter()
|
|
||||||
logger.debug(
|
|
||||||
"NIXL handshake: get metadata took: %s", got_metadata_time - start_time
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Ensure engine id matches.
|
# Send query for the request.
|
||||||
if metadata.engine_id != expected_engine_id:
|
with zmq_ctx(zmq.REQ, path) as sock:
|
||||||
raise RuntimeError(
|
sock.send(GET_META_MSG)
|
||||||
f"Remote NIXL agent engine ID mismatch. "
|
metadata_bytes = sock.recv()
|
||||||
f"Expected {expected_engine_id},"
|
decoder = msgspec.msgpack.Decoder(NixlAgentMetadata)
|
||||||
f"received {metadata.engine_id}."
|
metadata = decoder.decode(metadata_bytes)
|
||||||
|
got_metadata_time = time.perf_counter()
|
||||||
|
logger.debug(
|
||||||
|
"NIXL handshake: get metadata took: %s", got_metadata_time - start_time
|
||||||
)
|
)
|
||||||
|
|
||||||
# Register Remote agent.
|
# Ensure engine id matches.
|
||||||
remote_agent_name = self.add_remote_agent(
|
if metadata.engine_id != expected_engine_id:
|
||||||
metadata, p_remote_rank, remote_tp_size
|
raise RuntimeError(
|
||||||
)
|
f"Remote NIXL agent engine ID mismatch. "
|
||||||
setup_agent_time = time.perf_counter()
|
f"Expected {expected_engine_id},"
|
||||||
logger.debug(
|
f"received {metadata.engine_id}."
|
||||||
"NIXL handshake: add agent took: %s",
|
)
|
||||||
setup_agent_time - got_metadata_time,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Remote rank -> agent name.
|
# Register Remote agent.
|
||||||
return {p_remote_rank: remote_agent_name}
|
remote_agent_name = self.add_remote_agent(
|
||||||
|
metadata, remote_rank, remote_tp_size
|
||||||
|
)
|
||||||
|
setup_agent_time = time.perf_counter()
|
||||||
|
logger.debug(
|
||||||
|
"NIXL handshake: add agent took: %s",
|
||||||
|
setup_agent_time - got_metadata_time,
|
||||||
|
)
|
||||||
|
remote_rank_to_agent_name[remote_rank] = remote_agent_name
|
||||||
|
|
||||||
|
return remote_rank_to_agent_name
|
||||||
|
|
||||||
def initialize_host_xfer_buffer(self, kv_caches: dict[str, torch.Tensor]) -> None:
|
def initialize_host_xfer_buffer(self, kv_caches: dict[str, torch.Tensor]) -> None:
|
||||||
"""
|
"""
|
||||||
@ -944,7 +957,7 @@ class NixlConnectorWorker:
|
|||||||
assert len(self.block_len_per_layer) == len(seen_base_addresses)
|
assert len(self.block_len_per_layer) == len(seen_base_addresses)
|
||||||
assert self.num_blocks != 0
|
assert self.num_blocks != 0
|
||||||
|
|
||||||
self.kv_caches_base_addr[self.engine_id] = seen_base_addresses
|
self.kv_caches_base_addr[self.engine_id][self.tp_rank] = seen_base_addresses
|
||||||
self.num_regions = len(caches_data)
|
self.num_regions = len(caches_data)
|
||||||
self.num_layers = len(xfer_buffers.keys())
|
self.num_layers = len(xfer_buffers.keys())
|
||||||
|
|
||||||
@ -1033,7 +1046,7 @@ class NixlConnectorWorker:
|
|||||||
metadata = NixlAgentMetadata(
|
metadata = NixlAgentMetadata(
|
||||||
engine_id=self.engine_id,
|
engine_id=self.engine_id,
|
||||||
agent_metadata=self.nixl_wrapper.get_agent_metadata(),
|
agent_metadata=self.nixl_wrapper.get_agent_metadata(),
|
||||||
kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id],
|
kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id][self.tp_rank],
|
||||||
num_blocks=self.num_blocks,
|
num_blocks=self.num_blocks,
|
||||||
block_lens=self.block_len_per_layer,
|
block_lens=self.block_len_per_layer,
|
||||||
attn_backend_name=self.backend_name,
|
attn_backend_name=self.backend_name,
|
||||||
@ -1096,7 +1109,7 @@ class NixlConnectorWorker:
|
|||||||
engine_id = nixl_agent_meta.engine_id
|
engine_id = nixl_agent_meta.engine_id
|
||||||
# TODO re-evaluate refreshing for scaling/recovery
|
# TODO re-evaluate refreshing for scaling/recovery
|
||||||
if remote_tp_rank in self._remote_agents.get(engine_id, {}):
|
if remote_tp_rank in self._remote_agents.get(engine_id, {}):
|
||||||
logger.debug(
|
logger.warning(
|
||||||
"Remote agent with engine_id %s and rank"
|
"Remote agent with engine_id %s and rank"
|
||||||
"%s already exchanged metadata, skip handshake.",
|
"%s already exchanged metadata, skip handshake.",
|
||||||
engine_id,
|
engine_id,
|
||||||
@ -1114,6 +1127,7 @@ class NixlConnectorWorker:
|
|||||||
|
|
||||||
# Handle tp_size>num_kv_heads: replicate KV cache.
|
# Handle tp_size>num_kv_heads: replicate KV cache.
|
||||||
replicates_kv_cache = self.kv_info.replicates_kv_cache(engine_id)
|
replicates_kv_cache = self.kv_info.replicates_kv_cache(engine_id)
|
||||||
|
print("REPLICATES KV CACHE", replicates_kv_cache, "\n")
|
||||||
|
|
||||||
# Create dst descs and xfer side handles. TP workers have same #blocks
|
# Create dst descs and xfer side handles. TP workers have same #blocks
|
||||||
# so we only register once per engine_id.
|
# so we only register once per engine_id.
|
||||||
@ -1121,11 +1135,12 @@ class NixlConnectorWorker:
|
|||||||
self.dst_num_blocks[engine_id] = nixl_agent_meta.num_blocks
|
self.dst_num_blocks[engine_id] = nixl_agent_meta.num_blocks
|
||||||
|
|
||||||
# Keep track of remote agent kv caches base addresses.
|
# Keep track of remote agent kv caches base addresses.
|
||||||
self.kv_caches_base_addr[engine_id] = nixl_agent_meta.kv_caches_base_addr
|
self.kv_caches_base_addr[engine_id][self.tp_rank] = nixl_agent_meta.kv_caches_base_addr
|
||||||
self._validate_remote_agent_handshake(nixl_agent_meta, remote_tp_size)
|
self._validate_remote_agent_handshake(nixl_agent_meta, remote_tp_size)
|
||||||
|
|
||||||
# Number of D TP workers reading from a single P TP worker. This is
|
# Number of D TP workers reading from a single P TP worker. This is
|
||||||
# 1 when P and D `--tensor-parallel-size` match.
|
# 1 when P and D `--tensor-parallel-size` match. If P TP > D TP,
|
||||||
|
# we don't need to use this for spliting the remote kv cache.
|
||||||
tp_ratio = self.kv_info.tp_ratio(engine_id)
|
tp_ratio = self.kv_info.tp_ratio(engine_id)
|
||||||
|
|
||||||
### Register remote agent memory regions
|
### Register remote agent memory regions
|
||||||
@ -1137,7 +1152,8 @@ class NixlConnectorWorker:
|
|||||||
|
|
||||||
# Register all remote blocks, but only the corresponding kv heads.
|
# Register all remote blocks, but only the corresponding kv heads.
|
||||||
for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr):
|
for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr):
|
||||||
kv_block_len = self.get_backend_aware_kv_block_len(layer_idx=i)
|
# TODO
|
||||||
|
kv_block_len = self.get_backend_aware_kv_block_len(layer_idx=i) // 2
|
||||||
rank_offset = (
|
rank_offset = (
|
||||||
self.tp_rank % tp_ratio * kv_block_len if not replicates_kv_cache else 0
|
self.tp_rank % tp_ratio * kv_block_len if not replicates_kv_cache else 0
|
||||||
)
|
)
|
||||||
@ -1168,7 +1184,7 @@ class NixlConnectorWorker:
|
|||||||
|
|
||||||
# Register with NIXL.
|
# Register with NIXL.
|
||||||
descs = self.nixl_wrapper.get_xfer_descs(blocks_data, self.nixl_memory_type)
|
descs = self.nixl_wrapper.get_xfer_descs(blocks_data, self.nixl_memory_type)
|
||||||
self.dst_xfer_side_handles[engine_id] = self.nixl_wrapper.prep_xfer_dlist(
|
self.dst_xfer_side_handles[engine_id][remote_tp_rank] = self.nixl_wrapper.prep_xfer_dlist(
|
||||||
remote_agent_name, descs
|
remote_agent_name, descs
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -1189,7 +1205,6 @@ class NixlConnectorWorker:
|
|||||||
assert nixl_agent_meta.kv_cache_layout == self.kv_cache_layout
|
assert nixl_agent_meta.kv_cache_layout == self.kv_cache_layout
|
||||||
|
|
||||||
tp_ratio = self.kv_info.tp_ratio(remote_engine_id)
|
tp_ratio = self.kv_info.tp_ratio(remote_engine_id)
|
||||||
assert tp_ratio > 0, "Decode TP cannot be smaller than prefill TP"
|
|
||||||
assert not self._use_pallas or tp_ratio == 1, (
|
assert not self._use_pallas or tp_ratio == 1, (
|
||||||
"TPU (pallas_v1) DOES NOT support heterogeneous TP yet."
|
"TPU (pallas_v1) DOES NOT support heterogeneous TP yet."
|
||||||
)
|
)
|
||||||
@ -1217,21 +1232,20 @@ class NixlConnectorWorker:
|
|||||||
if self._use_flashinfer:
|
if self._use_flashinfer:
|
||||||
# With flashinfer, KV are sent in the same message.
|
# With flashinfer, KV are sent in the same message.
|
||||||
remote_block_size //= 2
|
remote_block_size //= 2
|
||||||
|
# TODO add asserts for P TP > D TP
|
||||||
|
# assert remote_block_len == self.block_len_per_layer[0] * tp_ratio, (
|
||||||
|
# "Remote P worker KV layer cache must be of shape [2, N, "
|
||||||
|
# "local_kv_heads*tp_ratio, block_size, head_dim] and same dtype."
|
||||||
|
# )
|
||||||
|
|
||||||
assert remote_block_len == self.block_len_per_layer[0] * tp_ratio, (
|
# assert self.block_size == remote_block_size, (
|
||||||
"Remote P worker KV layer cache must be of shape [2, N, "
|
# "Remote P worker with different page/block size is not supported "
|
||||||
"local_kv_heads*tp_ratio, block_size, head_dim] and same dtype."
|
# f"{self.block_size=}, {remote_block_size=}"
|
||||||
)
|
# )
|
||||||
|
|
||||||
assert self.block_size == remote_block_size, (
|
# # TP workers have same #blocks.
|
||||||
"Remote P worker with different page/block size is not supported "
|
# assert self.dst_num_blocks[remote_engine_id] == nixl_agent_meta.num_blocks
|
||||||
f"{self.block_size=}, {remote_block_size=}"
|
# assert len(nixl_agent_meta.kv_caches_base_addr) == len(self.block_len_per_layer)
|
||||||
)
|
|
||||||
|
|
||||||
# TP workers have same #blocks.
|
|
||||||
assert self.dst_num_blocks[remote_engine_id] == nixl_agent_meta.num_blocks
|
|
||||||
|
|
||||||
assert len(nixl_agent_meta.kv_caches_base_addr) == len(self.block_len_per_layer)
|
|
||||||
|
|
||||||
def sync_recved_kv_to_device(self, req_id: str, meta: ReqMeta):
|
def sync_recved_kv_to_device(self, req_id: str, meta: ReqMeta):
|
||||||
"""copy recved kv from host buffer to device."""
|
"""copy recved kv from host buffer to device."""
|
||||||
@ -1431,17 +1445,24 @@ class NixlConnectorWorker:
|
|||||||
self._reqs_to_send[req_id] = expiration_time
|
self._reqs_to_send[req_id] = expiration_time
|
||||||
|
|
||||||
def _read_blocks_for_req(self, req_id: str, meta: ReqMeta):
|
def _read_blocks_for_req(self, req_id: str, meta: ReqMeta):
|
||||||
logger.debug(
|
remote_ranks = self.kv_info.get_target_remote_ranks(meta.remote_engine_id)
|
||||||
"Remote agent %s available, calling _read_blocks for req %s",
|
# D may perform multiple reads from different remote ranks.
|
||||||
meta.remote_engine_id,
|
for remote_rank in remote_ranks:
|
||||||
req_id,
|
logger.debug(
|
||||||
)
|
"Remote agent %s available, calling _read_blocks"
|
||||||
self._read_blocks(
|
" on remote rank %s for req %s",
|
||||||
request_id=req_id,
|
meta.remote_engine_id,
|
||||||
dst_engine_id=meta.remote_engine_id,
|
remote_rank,
|
||||||
local_block_ids=meta.local_block_ids,
|
req_id,
|
||||||
remote_block_ids=meta.remote_block_ids,
|
)
|
||||||
)
|
# TODO multiread; notifs to all twice?? SPLIT LOCAL BLOCKS!
|
||||||
|
self._read_blocks(
|
||||||
|
request_id=req_id,
|
||||||
|
dst_engine_id=meta.remote_engine_id,
|
||||||
|
local_block_ids=meta.local_block_ids,
|
||||||
|
remote_block_ids=meta.remote_block_ids,
|
||||||
|
remote_rank=remote_rank,
|
||||||
|
)
|
||||||
|
|
||||||
def _read_blocks(
|
def _read_blocks(
|
||||||
self,
|
self,
|
||||||
@ -1449,6 +1470,7 @@ class NixlConnectorWorker:
|
|||||||
remote_block_ids: list[int],
|
remote_block_ids: list[int],
|
||||||
dst_engine_id: str,
|
dst_engine_id: str,
|
||||||
request_id: str,
|
request_id: str,
|
||||||
|
remote_rank: int,
|
||||||
):
|
):
|
||||||
# NOTE(rob): having the staging blocks be on the READER side is
|
# NOTE(rob): having the staging blocks be on the READER side is
|
||||||
# not going to work well (since we will have to call rearrange tensors).
|
# not going to work well (since we will have to call rearrange tensors).
|
||||||
@ -1462,14 +1484,14 @@ class NixlConnectorWorker:
|
|||||||
|
|
||||||
# Number of D TP workers that will read from dst P. Propagate tp_ratio
|
# Number of D TP workers that will read from dst P. Propagate tp_ratio
|
||||||
# on notification so that dst worker can wait before freeing blocks.
|
# on notification so that dst worker can wait before freeing blocks.
|
||||||
tp_ratio = self.kv_info.tp_ratio(dst_engine_id)
|
# Cap to 1 when P TP > D TP: only a single rank will read from remote.
|
||||||
|
tp_ratio = max(1, self.kv_info.tp_ratio(dst_engine_id))
|
||||||
notif_id = f"{request_id}:{tp_ratio}".encode()
|
notif_id = f"{request_id}:{tp_ratio}".encode()
|
||||||
|
|
||||||
# Full prefix cache hit: do not need to read remote blocks,
|
# Full prefix cache hit: do not need to read remote blocks,
|
||||||
# just notify P worker that we have the blocks we need.
|
# just notify P worker(s) that we have the blocks we need.
|
||||||
num_local_blocks = len(local_block_ids)
|
num_local_blocks = len(local_block_ids)
|
||||||
if num_local_blocks == 0:
|
if num_local_blocks == 0:
|
||||||
remote_rank = self.kv_info.get_target_remote_rank(dst_engine_id)
|
|
||||||
agent_name = self._remote_agents[dst_engine_id][remote_rank]
|
agent_name = self._remote_agents[dst_engine_id][remote_rank]
|
||||||
self.nixl_wrapper.send_notif(agent_name, notif_msg=notif_id)
|
self.nixl_wrapper.send_notif(agent_name, notif_msg=notif_id)
|
||||||
return
|
return
|
||||||
@ -1482,7 +1504,7 @@ class NixlConnectorWorker:
|
|||||||
|
|
||||||
# Get side handles.
|
# Get side handles.
|
||||||
local_xfer_side_handle = self.src_xfer_side_handle
|
local_xfer_side_handle = self.src_xfer_side_handle
|
||||||
remote_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id]
|
remote_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id][remote_rank]
|
||||||
|
|
||||||
# NOTE (nicolo) With homogeneous TP, each TP worker loads KV from
|
# NOTE (nicolo) With homogeneous TP, each TP worker loads KV from
|
||||||
# corresponding rank. With heterogeneous TP, fixing D>P, the D tp
|
# corresponding rank. With heterogeneous TP, fixing D>P, the D tp
|
||||||
@ -1571,6 +1593,7 @@ class NixlConnectorWorker:
|
|||||||
assert self.num_layers == self.num_regions
|
assert self.num_layers == self.num_regions
|
||||||
region_ids = np.arange(layer_idx, layer_idx + 1)
|
region_ids = np.arange(layer_idx, layer_idx + 1)
|
||||||
|
|
||||||
|
# TODO can this vary?
|
||||||
num_blocks = self.dst_num_blocks[engine_id]
|
num_blocks = self.dst_num_blocks[engine_id]
|
||||||
|
|
||||||
# Compute the desc ids for each block.
|
# Compute the desc ids for each block.
|
||||||
@ -1617,8 +1640,9 @@ class NixlConnectorWorker:
|
|||||||
if self.src_xfer_side_handle:
|
if self.src_xfer_side_handle:
|
||||||
self.nixl_wrapper.release_dlist_handle(self.src_xfer_side_handle)
|
self.nixl_wrapper.release_dlist_handle(self.src_xfer_side_handle)
|
||||||
self.src_xfer_side_handle = 0
|
self.src_xfer_side_handle = 0
|
||||||
for dst_xfer_side_handle in self.dst_xfer_side_handles.values():
|
for dst_xfer_side_handles in self.dst_xfer_side_handles.values():
|
||||||
self.nixl_wrapper.release_dlist_handle(dst_xfer_side_handle)
|
for dst_xfer_side_handle in dst_xfer_side_handles.values():
|
||||||
|
self.nixl_wrapper.release_dlist_handle(dst_xfer_side_handle)
|
||||||
self.dst_xfer_side_handles.clear()
|
self.dst_xfer_side_handles.clear()
|
||||||
for remote_agents in self._remote_agents.values():
|
for remote_agents in self._remote_agents.values():
|
||||||
for agent_name in remote_agents.values():
|
for agent_name in remote_agents.values():
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user