Mirror of https://git.datalinker.icu/vllm-project/vllm.git
[Nixl] Minor refactor to handshake related metadata (#26410)
Signed-off-by: NickLucche <nlucches@redhat.com>
commit 72f431e709
parent be4445072c
@@ -565,8 +565,6 @@ class TestNixlHandshake:
                kv_cache_layout=mismatched_layout,
            )

-            # We don't check layout for homogeneous TP and MLA for now, as the
-            # whole block is moved.
            with pytest.raises(RuntimeError):
                # mismatched layout is expected to fail
                worker.add_remote_agent(meta, remote_tp_size=2)
@@ -36,7 +36,6 @@ from vllm.distributed.parallel_state import (
     get_tensor_model_parallel_world_size,
     get_tp_group,
 )
-from vllm.distributed.utils import divide
 from vllm.forward_context import ForwardContext
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
@@ -521,6 +520,72 @@ class NixlConnectorScheduler:
 class NixlConnectorWorker:
     """Implementation of Worker side methods"""

+    @dataclass
+    class TpKVTopology:
+        """
+        Helper class for tensor parallel and KV topology information for
+        mapping between local and remote TP workers.
+        """
+
+        tp_size: int
+        tp_rank: int
+        remote_tp_size: dict[EngineId, int]
+        is_mla: bool
+        total_num_kv_heads: int
+
+        def tp_ratio(
+            self,
+            remote_tp_size: int,
+        ) -> int:
+            """
+            Calculate the tensor parallel ratio between local and remote TP.
+            We can think of it as the number of local TP workers-per-remote TP
+            workers. Local workers will read from the same remote TP worker in
+            groups of size `tp_ratio`.
+            """
+            assert self.tp_size % remote_tp_size == 0, (
+                f"Local tensor parallel size {self.tp_size} is not divisible "
+                f"by remote tensor parallel size {remote_tp_size}."
+            )
+            return self.tp_size // remote_tp_size
+
+        def tp_ratio_from_engine_id(
+            self,
+            remote_engine_id: EngineId,
+        ) -> int:
+            remote_tp_size = self.remote_tp_size[remote_engine_id]
+            return self.tp_ratio(remote_tp_size)
+
+        def is_kv_replicated(self, engine_id: EngineId) -> bool:
+            """
+            Whether the KV cache is replicated across TP workers due to the
+            number of TP workers being greater than the number of KV heads.
+            """
+            tp_size = self.remote_tp_size[engine_id]
+            return tp_size // self.total_num_kv_heads >= 1
+
+        def replicates_kv_cache(self, remote_engine_id: EngineId) -> bool:
+            # MLA is always replicated as the hidden dim can't be split.
+            return self.is_mla or self.is_kv_replicated(remote_engine_id)
+
+        def get_target_remote_rank(
+            self,
+            remote_tp_size: int,
+        ) -> int:
+            """
+            Get the remote TP rank (on P) that the current local TP rank
+            (on D) will read from.
+            """
+            tp_ratio = self.tp_ratio(remote_tp_size)
+            return self.tp_rank // tp_ratio
+
+        def get_target_remote_rank_from_engine_id(
+            self,
+            remote_engine_id: EngineId,
+        ) -> int:
+            remote_tp_size = self.remote_tp_size[remote_engine_id]
+            return self.get_target_remote_rank(remote_tp_size)
+
     def __init__(self, vllm_config: VllmConfig, engine_id: str):
         if NixlWrapper is None:
             logger.error("NIXL is not available")
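Note: the mapping implemented by TpKVTopology is easiest to see with concrete numbers. The following standalone sketch (a simplified copy of the helper, with a hypothetical engine id and TP sizes, for illustration only) shows how groups of `tp_ratio` local decode ranks share one remote prefill rank:

    from dataclasses import dataclass


    @dataclass
    class TpKVTopology:
        """Simplified copy of the helper, for illustration only."""

        tp_size: int
        tp_rank: int
        remote_tp_size: dict[str, int]
        is_mla: bool
        total_num_kv_heads: int

        def tp_ratio(self, remote_tp_size: int) -> int:
            # Number of local TP workers per remote TP worker.
            assert self.tp_size % remote_tp_size == 0
            return self.tp_size // remote_tp_size

        def get_target_remote_rank(self, remote_tp_size: int) -> int:
            # Groups of `tp_ratio` local ranks read from the same remote rank.
            return self.tp_rank // self.tp_ratio(remote_tp_size)


    # Decode runs TP=4, prefill engine "prefill-0" runs TP=2 => tp_ratio=2.
    for rank in range(4):
        topo = TpKVTopology(
            tp_size=4,
            tp_rank=rank,
            remote_tp_size={"prefill-0": 2},
            is_mla=False,
            total_num_kv_heads=8,
        )
        print(rank, "->", topo.get_target_remote_rank(2))
    # Prints: 0 -> 0, 1 -> 0, 2 -> 1, 3 -> 1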
@@ -534,6 +599,7 @@ class NixlConnectorWorker:

         if vllm_config.kv_transfer_config is None:
             raise ValueError("kv_transfer_config must be set for NixlConnector")
+        self.kv_transfer_config = vllm_config.kv_transfer_config

         self.nixl_backends = vllm_config.kv_transfer_config.get_from_extra_config(
             "backends", ["UCX"]
@@ -654,7 +720,6 @@ class NixlConnectorWorker:
         # Protects _handshake_futures and _remote_agents.
         self._handshake_lock = threading.RLock()

-        self.vllm_config = vllm_config
         self.block_size = vllm_config.cache_config.block_size
         self.model_config = vllm_config.model_config
         self.cache_config = vllm_config.cache_config
@@ -686,6 +751,14 @@ class NixlConnectorWorker:
         self.consumer_notification_counts_by_req = defaultdict[ReqId, int](int)
         self.xfer_stats = NixlKVConnectorStats()

+        self.kv_topo = self.TpKVTopology(
+            tp_size=self.world_size,
+            tp_rank=self.tp_rank,
+            remote_tp_size=self._tp_size,  # shared state
+            is_mla=self.use_mla,
+            total_num_kv_heads=self.model_config.get_total_num_kv_heads(),
+        )
+
     @staticmethod
     def _nixl_handshake_listener(
         metadata: NixlAgentMetadata,
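Note: the "# shared state" comment above relies on the helper holding a reference to the worker's dict rather than a copy, so remote TP sizes recorded during later handshakes are visible without any extra wiring. A minimal sketch of that behaviour (variable names here are illustrative, not the connector's real state):

    # `remote_tp_size` stands in for self._tp_size; `topo_view` is what the
    # TpKVTopology instance stores.
    remote_tp_size: dict[str, int] = {}
    topo_view = remote_tp_size

    remote_tp_size["prefill-0"] = 2          # populated during a later handshake
    assert topo_view["prefill-0"] == 2       # the helper sees it immediately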
@@ -731,8 +804,7 @@ class NixlConnectorWorker:

         # Handshake only with the remote TP rank that current local rank will
         # pull from. With homogeneous TP it happens to be the same rank_i.
-        tp_ratio = self._tp_size[self.engine_id] // remote_tp_size
-        p_remote_rank = self.tp_rank // tp_ratio
+        p_remote_rank = self.kv_topo.get_target_remote_rank(remote_tp_size)
         path = make_zmq_path("tcp", host, port + p_remote_rank)
         logger.debug(
             "Querying metadata on path: %s at remote rank %s", path, p_remote_rank
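Note: because the side-channel port is offset by the target remote rank, each decode rank dials a different prefill listener. A minimal sketch, assuming a hypothetical host and base port (the real code builds the address with make_zmq_path):

    # Illustrative only: each local (decode) rank contacts exactly one remote
    # (prefill) rank, offsetting the base side-channel port by that rank.
    def handshake_path(host: str, base_port: int, tp_rank: int, tp_ratio: int) -> str:
        p_remote_rank = tp_rank // tp_ratio
        return f"tcp://{host}:{base_port + p_remote_rank}"

    # Decode TP=4 pulling from prefill TP=2 (tp_ratio=2), hypothetical host/port.
    for rank in range(4):
        print(rank, handshake_path("10.0.0.1", 5600, rank, tp_ratio=2))
    # Ranks 0,1 -> port 5600 (remote rank 0); ranks 2,3 -> port 5601 (remote rank 1).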
@@ -989,13 +1061,11 @@ class NixlConnectorWorker:

         # TODO(mgoin): Hybrid memory allocator is currently disabled for
         # models with local attention (Llama 4). Can remove this once enabled.
-        if self.vllm_config.model_config.hf_config.model_type == "llama4":
+        if self.model_config.hf_config.model_type == "llama4":
             from transformers import Llama4TextConfig

-            assert isinstance(
-                self.vllm_config.model_config.hf_text_config, Llama4TextConfig
-            )
-            llama4_config = self.vllm_config.model_config.hf_text_config
+            assert isinstance(self.model_config.hf_text_config, Llama4TextConfig)
+            llama4_config = self.model_config.hf_text_config
             no_rope_layers = llama4_config.no_rope_layers
             chunk_size = llama4_config.attention_chunk_size
             chunk_block_size = math.ceil(chunk_size / self.block_size)
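Note: the chunk_block_size computation above is plain ceiling division. A tiny worked example with hypothetical numbers:

    import math

    # Hypothetical sizes: a 8192-token attention chunk with 48-token KV blocks
    # needs ceil(8192 / 48) = 171 blocks to cover one chunk.
    chunk_size = 8192
    block_size = 48
    chunk_block_size = math.ceil(chunk_size / block_size)
    print(chunk_block_size)  # 171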
@@ -1078,107 +1148,51 @@ class NixlConnectorWorker:
         engine_id = nixl_agent_meta.engine_id
         # TODO re-evaluate refreshing for scaling/recovery
         if remote_tp_rank in self._remote_agents.get(engine_id, {}):
+            logger.debug(
+                "Remote agent with engine_id %s and rank"
+                "%s already exchanged metadata, skip handshake.",
+                engine_id,
+                remote_tp_rank,
+            )
             return self._remote_agents[engine_id][remote_tp_rank]

+        ### Register remote agent metadata
         if engine_id not in self._tp_size:
             self._tp_size[engine_id] = remote_tp_size
-        else:
-            assert self._tp_size[engine_id] == remote_tp_size
-        # TODO We may eventually want to skip enforcing the same attn backend.
-        assert nixl_agent_meta.attn_backend_name == self.backend_name

         remote_agent_name = self.nixl_wrapper.add_remote_agent(
             nixl_agent_meta.agent_metadata
         )

-        # Number of D TP workers reading from a single P TP worker. This is
-        # 1 when P and D `--tensor-parallel-size` match.
-        tp_ratio = divide(self._tp_size[self.engine_id], self._tp_size[engine_id])
-        assert tp_ratio > 0, "Decode TP cannot be smaller than prefill TP"
-        assert not self._use_pallas or tp_ratio == 1, (
-            "TPU (pallas_v1) DOES NOT support heterogeneous TP yet."
-        )
-
         # Handle tp_size>num_kv_heads: replicate KV cache.
-        total_num_kv_heads = self.model_config.get_total_num_kv_heads()
-        is_kv_replicated = self._tp_size[engine_id] // total_num_kv_heads >= 1
+        replicates_kv_cache = self.kv_topo.replicates_kv_cache(engine_id)

-        remote_block_len = nixl_agent_meta.block_lens[0]
-        if nixl_agent_meta.kv_cache_layout != self.kv_cache_layout:
-            if (
-                self.vllm_config.kv_transfer_config is not None
-                and self.vllm_config.kv_transfer_config.enable_permute_local_kv
-                and nixl_agent_meta.kv_cache_layout == "HND"
-            ):
-                logger.info(
-                    "Remote is HND and local is NHD, enabled additional permute "
-                    "on local device KV."
-                )
-                self.enable_permute_local_kv = True
-            else:
-                raise RuntimeError(
-                    "Heterogeneous TP expects same kv_cache_layout. "
-                    "Or enable experimental feature to use HND to NHD support by "
-                    "setting 'enable_permute_local_kv'=True in --kv-transfer-config."
-                )
-        if self.use_mla or is_kv_replicated:
-            # With replicated KV cache, only the number of blocks can differ.
-            assert self.block_len_per_layer == nixl_agent_meta.block_lens, (
-                "KV cache sizes must match between P and D when replicated"
-            )
-            remote_block_size = remote_block_len // (self.slot_size_per_layer[0])
-        else:
-            # When MLA is not used, this is a list of the same block length
-            for block_len in nixl_agent_meta.block_lens:
-                assert block_len == remote_block_len, (
-                    "All remote layers must have the same block size"
-                )
-            remote_block_size = remote_block_len // (
-                self.slot_size_per_layer[0] * tp_ratio
-            )
-            if self._use_flashinfer:
-                # With flashinfer, KV are sent in the same message.
-                remote_block_size //= 2
-            if tp_ratio > 1:
-                # Heterogeneous TP expects same kv_cache_layout.
-                if nixl_agent_meta.kv_cache_layout == "NHD":
-                    raise ValueError(
-                        "Heterogeneous TP is not supported for remote with NHD."
-                    )
-                if self.device_type == "xpu":
-                    raise ValueError("Heterogeneous TP is not supported on XPU")
-
-            assert remote_block_len == self.block_len_per_layer[0] * tp_ratio, (
-                "Remote P worker KV layer cache must be of shape [2, N, "
-                "local_kv_heads*tp_ratio, block_size, head_dim] and same dtype."
-            )
-
-        assert self.block_size == remote_block_size, (
-            "Remote P worker with different page/block size is not supported "
-            f"{self.block_size=}, {remote_block_size=}"
-        )
-
-        # Create dst descs and xfer side handles. TP workers have same #blocks.
-        if engine_id in self.dst_num_blocks:
-            assert self.dst_num_blocks[engine_id] == nixl_agent_meta.num_blocks
-        else:
+        # Create dst descs and xfer side handles. TP workers have same #blocks
+        # so we only register once per engine_id.
+        if engine_id not in self.dst_num_blocks:
             self.dst_num_blocks[engine_id] = nixl_agent_meta.num_blocks

+        # Keep track of remote agent kv caches base addresses.
+        self.kv_caches_base_addr[engine_id] = nixl_agent_meta.kv_caches_base_addr
+
+        self._validate_remote_agent_handshake(nixl_agent_meta, remote_tp_size)
+
+        # Number of D TP workers reading from a single P TP worker. This is
+        # 1 when P and D `--tensor-parallel-size` match.
+        tp_ratio = self.kv_topo.tp_ratio_from_engine_id(engine_id)
+
+        ### Register remote agent memory regions
         blocks_data = []
         # With homogeneous TP, D pulls the whole kv cache from corresponding
         # rank. With heterogeneous TP, prepare the descriptors by splitting the
         # P KV cache along kv_head dim, of D worker's kv_head size (D>P).
         # Eg. PTP1 DTP2 => P0 KV:[block0-KV_0 | block0-KV_1..].
-        self.kv_caches_base_addr[engine_id] = nixl_agent_meta.kv_caches_base_addr
-
-        assert len(nixl_agent_meta.kv_caches_base_addr) == len(self.block_len_per_layer)
         # Register all remote blocks, but only the corresponding kv heads.
         for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr):
             kv_block_len = self.get_backend_aware_kv_block_len(layer_idx=i)
             rank_offset = (
-                self.tp_rank % tp_ratio * kv_block_len
-                if not (self.use_mla or is_kv_replicated)
-                else 0
+                self.tp_rank % tp_ratio * kv_block_len if not replicates_kv_cache else 0
             )
             for block_id in range(nixl_agent_meta.num_blocks):
                 block_offset = block_id * nixl_agent_meta.block_lens[i]
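Note: the rank_offset expression above is what lets each decode rank read only its slice of a larger prefill block when TP is heterogeneous, and the whole block when the KV cache is replicated (MLA or tp_size > num_kv_heads). A small worked sketch with hypothetical sizes, not the connector's actual descriptor code:

    # A remote (prefill) block holds tp_ratio local-sized slices back to back.
    def rank_offset(tp_rank: int, tp_ratio: int, kv_block_len: int,
                    replicates_kv_cache: bool) -> int:
        return 0 if replicates_kv_cache else tp_rank % tp_ratio * kv_block_len

    kv_block_len = 1024  # bytes of one decode-rank slice (hypothetical)
    for rank in range(4):
        print(rank, rank_offset(rank, tp_ratio=2, kv_block_len=kv_block_len,
                                replicates_kv_cache=False))
    # Ranks 0,2 read at offset 0; ranks 1,3 read at offset 1024 within each block.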
@@ -1213,6 +1227,80 @@ class NixlConnectorWorker:

         return remote_agent_name

+    def _validate_remote_agent_handshake(
+        self, nixl_agent_meta: NixlAgentMetadata, remote_tp_size: int
+    ):
+        """
+        Validate the remote agent handshake metadata ensuring the
+        invariants hold true.
+        """
+        remote_engine_id = nixl_agent_meta.engine_id
+
+        assert self._tp_size[remote_engine_id] == remote_tp_size
+        # TODO We may eventually want to skip enforcing the same attn backend.
+        assert nixl_agent_meta.attn_backend_name == self.backend_name
+
+        tp_ratio = self.kv_topo.tp_ratio_from_engine_id(remote_engine_id)
+        assert tp_ratio > 0, "Decode TP cannot be smaller than prefill TP"
+        assert not self._use_pallas or tp_ratio == 1, (
+            "TPU (pallas_v1) DOES NOT support heterogeneous TP yet."
+        )
+        if not self.use_mla and nixl_agent_meta.kv_cache_layout != self.kv_cache_layout:
+            if (
+                self.kv_transfer_config.enable_permute_local_kv
+                and nixl_agent_meta.kv_cache_layout == "HND"
+            ):
+                logger.info(
+                    "Remote is HND and local is NHD, enabled additional permute "
+                    "on local device KV."
+                )
+                self.enable_permute_local_kv = True
+            else:
+                raise RuntimeError(
+                    "Heterogeneous TP expects same kv_cache_layout. "
+                    "Or enable experimental feature to use HND to NHD support by "
+                    "setting 'enable_permute_local_kv'=True in --kv-transfer-config."
+                )
+
+        # Block len can only vary across layers when using MLA.
+        remote_block_len = nixl_agent_meta.block_lens[0]
+        if self.use_mla or self.kv_topo.is_kv_replicated(remote_engine_id):
+            # With replicated KV cache, only the number of blocks can differ.
+            assert self.block_len_per_layer == nixl_agent_meta.block_lens, (
+                "KV cache sizes must match between P and D when replicated"
+            )
+            remote_block_size = remote_block_len // (self.slot_size_per_layer[0])
+        else:
+            if tp_ratio > 1 and self.device_type == "xpu":
+                # XPU uses NHD, hence it does not support splitting on H
+                raise ValueError("Heterogeneous TP is not supported on XPU")
+            # When MLA is not used, this is a list of the same block length
+            for block_len in nixl_agent_meta.block_lens:
+                assert block_len == remote_block_len, (
+                    "All remote layers must have the same block size"
+                )
+            remote_block_size = remote_block_len // (
+                self.slot_size_per_layer[0] * tp_ratio
+            )
+            if self._use_flashinfer:
+                # With flashinfer, KV are sent in the same message.
+                remote_block_size //= 2
+
+            assert remote_block_len == self.block_len_per_layer[0] * tp_ratio, (
+                "Remote P worker KV layer cache must be of shape [2, N, "
+                "local_kv_heads*tp_ratio, block_size, head_dim] and same dtype."
+            )
+
+        assert self.block_size == remote_block_size, (
+            "Remote P worker with different page/block size is not supported "
+            f"{self.block_size=}, {remote_block_size=}"
+        )
+
+        # TP workers have same #blocks.
+        assert self.dst_num_blocks[remote_engine_id] == nixl_agent_meta.num_blocks
+
+        assert len(nixl_agent_meta.kv_caches_base_addr) == len(self.block_len_per_layer)
+
     def sync_recved_kv_to_device(self, req_id: str, meta: ReqMeta):
         """copy recved kv from host buffer to device."""
         assert self.use_host_buffer
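Note: the block-size invariant that _validate_remote_agent_handshake enforces boils down to simple arithmetic: the prefill block is tp_ratio times wider in the head dimension, so dividing its byte length by slot_size * tp_ratio (and halving it when K and V travel in the same FlashInfer region) must give back the local block size. A sketch with hypothetical sizes:

    # Hypothetical sizes to walk through the validation arithmetic above.
    slot_size = 2 * 8 * 128      # bytes per token slot on D (fp16, 8 heads, 128 dim)
    local_block_size = 16        # tokens per block on D
    tp_ratio = 2                 # D TP=4 pulling from P TP=2
    use_flashinfer = False

    # P holds tp_ratio times the heads per rank, so its per-layer block is larger.
    remote_block_len = local_block_size * slot_size * tp_ratio

    remote_block_size = remote_block_len // (slot_size * tp_ratio)
    if use_flashinfer:
        remote_block_size //= 2  # K and V are packed in the same region
    assert remote_block_size == local_block_size  # invariant checked in the diff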
@@ -1505,14 +1593,16 @@ class NixlConnectorWorker:

         # Number of D TP workers that will read from dst P. Propagate tp_ratio
         # on notification so that dst worker can wait before freeing blocks.
-        tp_ratio = self._tp_size[self.engine_id] // self._tp_size[dst_engine_id]
+        tp_ratio = self.kv_topo.tp_ratio_from_engine_id(dst_engine_id)
         notif_id = f"{request_id}:{tp_ratio}".encode()

         # Full prefix cache hit: do not need to read remote blocks,
         # just notify P worker that we have the blocks we need.
         num_local_blocks = len(local_block_ids)
         if num_local_blocks == 0:
-            remote_rank = self.tp_rank // tp_ratio
+            remote_rank = self.kv_topo.get_target_remote_rank_from_engine_id(
+                dst_engine_id
+            )
             agent_name = self._remote_agents[dst_engine_id][remote_rank]
             try:
                 self.nixl_wrapper.send_notif(agent_name, notif_msg=notif_id)
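Note: encoding tp_ratio into the notification id lets the prefill side know how many decode ranks will read from it before its blocks can be freed. The following standalone sketch of that bookkeeping is illustrative only; the request id and the counting logic on the receiving side are hypothetical, not the connector's real implementation:

    from collections import defaultdict

    request_id = "req-abc"       # hypothetical request id
    tp_ratio = 2                 # two D ranks read from each P rank
    notif_id = f"{request_id}:{tp_ratio}".encode()

    # On the P side, count notifications per request; free blocks once all
    # expected readers have reported in.
    counts: defaultdict[str, int] = defaultdict(int)
    for notif in [notif_id, notif_id]:           # one notification per D reader
        req, expected = notif.decode().rsplit(":", 1)
        counts[req] += 1
        if counts[req] == int(expected):
            print(f"all {expected} readers done for {req}; blocks can be freed")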