mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 03:05:02 +08:00
[Nixl] Minor refactor to handshake related metadata (#26410)
Signed-off-by: NickLucche <nlucches@redhat.com>
This commit is contained in:
parent
be4445072c
commit
72f431e709
@ -565,8 +565,6 @@ class TestNixlHandshake:
|
||||
kv_cache_layout=mismatched_layout,
|
||||
)
|
||||
|
||||
# We don't check layout for homogeneous TP and MLA for now, as the
|
||||
# whole block is moved.
|
||||
with pytest.raises(RuntimeError):
|
||||
# mismatched layout is expected to fail
|
||||
worker.add_remote_agent(meta, remote_tp_size=2)
|
||||
|
||||
@ -36,7 +36,6 @@ from vllm.distributed.parallel_state import (
|
||||
get_tensor_model_parallel_world_size,
|
||||
get_tp_group,
|
||||
)
|
||||
from vllm.distributed.utils import divide
|
||||
from vllm.forward_context import ForwardContext
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
@ -521,6 +520,72 @@ class NixlConnectorScheduler:
|
||||
class NixlConnectorWorker:
|
||||
"""Implementation of Worker side methods"""
|
||||
|
||||
@dataclass
|
||||
class TpKVTopology:
|
||||
"""
|
||||
Helper class for tensor parallel and KV topology information for
|
||||
mapping between local and remote TP workers.
|
||||
"""
|
||||
|
||||
tp_size: int
|
||||
tp_rank: int
|
||||
remote_tp_size: dict[EngineId, int]
|
||||
is_mla: bool
|
||||
total_num_kv_heads: int
|
||||
|
||||
def tp_ratio(
|
||||
self,
|
||||
remote_tp_size: int,
|
||||
) -> int:
|
||||
"""
|
||||
Calculate the tensor parallel ratio between local and remote TP.
|
||||
We can think of it as the number of local TP workers-per-remote TP
|
||||
workers. Local workers will read from the same remote TP worker in
|
||||
groups of size `tp_ratio`.
|
||||
"""
|
||||
assert self.tp_size % remote_tp_size == 0, (
|
||||
f"Local tensor parallel size {self.tp_size} is not divisible "
|
||||
f"by remote tensor parallel size {remote_tp_size}."
|
||||
)
|
||||
return self.tp_size // remote_tp_size
|
||||
|
||||
def tp_ratio_from_engine_id(
|
||||
self,
|
||||
remote_engine_id: EngineId,
|
||||
) -> int:
|
||||
remote_tp_size = self.remote_tp_size[remote_engine_id]
|
||||
return self.tp_ratio(remote_tp_size)
|
||||
|
||||
def is_kv_replicated(self, engine_id: EngineId) -> bool:
|
||||
"""
|
||||
Whether the KV cache is replicated across TP workers due to the
|
||||
number of TP workers being greater than the number of KV heads.
|
||||
"""
|
||||
tp_size = self.remote_tp_size[engine_id]
|
||||
return tp_size // self.total_num_kv_heads >= 1
|
||||
|
||||
def replicates_kv_cache(self, remote_engine_id: EngineId) -> bool:
|
||||
# MLA is always replicated as the hidden dim can't be split.
|
||||
return self.is_mla or self.is_kv_replicated(remote_engine_id)
|
||||
|
||||
def get_target_remote_rank(
|
||||
self,
|
||||
remote_tp_size: int,
|
||||
) -> int:
|
||||
"""
|
||||
Get the remote TP rank (on P) that the current local TP rank
|
||||
(on D) will read from.
|
||||
"""
|
||||
tp_ratio = self.tp_ratio(remote_tp_size)
|
||||
return self.tp_rank // tp_ratio
|
||||
|
||||
def get_target_remote_rank_from_engine_id(
|
||||
self,
|
||||
remote_engine_id: EngineId,
|
||||
) -> int:
|
||||
remote_tp_size = self.remote_tp_size[remote_engine_id]
|
||||
return self.get_target_remote_rank(remote_tp_size)
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig, engine_id: str):
|
||||
if NixlWrapper is None:
|
||||
logger.error("NIXL is not available")
|
||||
@ -534,6 +599,7 @@ class NixlConnectorWorker:
|
||||
|
||||
if vllm_config.kv_transfer_config is None:
|
||||
raise ValueError("kv_transfer_config must be set for NixlConnector")
|
||||
self.kv_transfer_config = vllm_config.kv_transfer_config
|
||||
|
||||
self.nixl_backends = vllm_config.kv_transfer_config.get_from_extra_config(
|
||||
"backends", ["UCX"]
|
||||
@ -654,7 +720,6 @@ class NixlConnectorWorker:
|
||||
# Protects _handshake_futures and _remote_agents.
|
||||
self._handshake_lock = threading.RLock()
|
||||
|
||||
self.vllm_config = vllm_config
|
||||
self.block_size = vllm_config.cache_config.block_size
|
||||
self.model_config = vllm_config.model_config
|
||||
self.cache_config = vllm_config.cache_config
|
||||
@ -686,6 +751,14 @@ class NixlConnectorWorker:
|
||||
self.consumer_notification_counts_by_req = defaultdict[ReqId, int](int)
|
||||
self.xfer_stats = NixlKVConnectorStats()
|
||||
|
||||
self.kv_topo = self.TpKVTopology(
|
||||
tp_size=self.world_size,
|
||||
tp_rank=self.tp_rank,
|
||||
remote_tp_size=self._tp_size, # shared state
|
||||
is_mla=self.use_mla,
|
||||
total_num_kv_heads=self.model_config.get_total_num_kv_heads(),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _nixl_handshake_listener(
|
||||
metadata: NixlAgentMetadata,
|
||||
@ -731,8 +804,7 @@ class NixlConnectorWorker:
|
||||
|
||||
# Handshake only with the remote TP rank that current local rank will
|
||||
# pull from. With homogeneous TP it happens to be the same rank_i.
|
||||
tp_ratio = self._tp_size[self.engine_id] // remote_tp_size
|
||||
p_remote_rank = self.tp_rank // tp_ratio
|
||||
p_remote_rank = self.kv_topo.get_target_remote_rank(remote_tp_size)
|
||||
path = make_zmq_path("tcp", host, port + p_remote_rank)
|
||||
logger.debug(
|
||||
"Querying metadata on path: %s at remote rank %s", path, p_remote_rank
|
||||
@ -989,13 +1061,11 @@ class NixlConnectorWorker:
|
||||
|
||||
# TODO(mgoin): Hybrid memory allocator is currently disabled for
|
||||
# models with local attention (Llama 4). Can remove this once enabled.
|
||||
if self.vllm_config.model_config.hf_config.model_type == "llama4":
|
||||
if self.model_config.hf_config.model_type == "llama4":
|
||||
from transformers import Llama4TextConfig
|
||||
|
||||
assert isinstance(
|
||||
self.vllm_config.model_config.hf_text_config, Llama4TextConfig
|
||||
)
|
||||
llama4_config = self.vllm_config.model_config.hf_text_config
|
||||
assert isinstance(self.model_config.hf_text_config, Llama4TextConfig)
|
||||
llama4_config = self.model_config.hf_text_config
|
||||
no_rope_layers = llama4_config.no_rope_layers
|
||||
chunk_size = llama4_config.attention_chunk_size
|
||||
chunk_block_size = math.ceil(chunk_size / self.block_size)
|
||||
@ -1078,107 +1148,51 @@ class NixlConnectorWorker:
|
||||
engine_id = nixl_agent_meta.engine_id
|
||||
# TODO re-evaluate refreshing for scaling/recovery
|
||||
if remote_tp_rank in self._remote_agents.get(engine_id, {}):
|
||||
logger.debug(
|
||||
"Remote agent with engine_id %s and rank"
|
||||
"%s already exchanged metadata, skip handshake.",
|
||||
engine_id,
|
||||
remote_tp_rank,
|
||||
)
|
||||
return self._remote_agents[engine_id][remote_tp_rank]
|
||||
|
||||
### Register remote agent metadata
|
||||
if engine_id not in self._tp_size:
|
||||
self._tp_size[engine_id] = remote_tp_size
|
||||
else:
|
||||
assert self._tp_size[engine_id] == remote_tp_size
|
||||
# TODO We may eventually want to skip enforcing the same attn backend.
|
||||
assert nixl_agent_meta.attn_backend_name == self.backend_name
|
||||
|
||||
remote_agent_name = self.nixl_wrapper.add_remote_agent(
|
||||
nixl_agent_meta.agent_metadata
|
||||
)
|
||||
|
||||
# Number of D TP workers reading from a single P TP worker. This is
|
||||
# 1 when P and D `--tensor-parallel-size` match.
|
||||
tp_ratio = divide(self._tp_size[self.engine_id], self._tp_size[engine_id])
|
||||
assert tp_ratio > 0, "Decode TP cannot be smaller than prefill TP"
|
||||
assert not self._use_pallas or tp_ratio == 1, (
|
||||
"TPU (pallas_v1) DOES NOT support heterogeneous TP yet."
|
||||
)
|
||||
|
||||
# Handle tp_size>num_kv_heads: replicate KV cache.
|
||||
total_num_kv_heads = self.model_config.get_total_num_kv_heads()
|
||||
is_kv_replicated = self._tp_size[engine_id] // total_num_kv_heads >= 1
|
||||
replicates_kv_cache = self.kv_topo.replicates_kv_cache(engine_id)
|
||||
|
||||
remote_block_len = nixl_agent_meta.block_lens[0]
|
||||
if nixl_agent_meta.kv_cache_layout != self.kv_cache_layout:
|
||||
if (
|
||||
self.vllm_config.kv_transfer_config is not None
|
||||
and self.vllm_config.kv_transfer_config.enable_permute_local_kv
|
||||
and nixl_agent_meta.kv_cache_layout == "HND"
|
||||
):
|
||||
logger.info(
|
||||
"Remote is HND and local is NHD, enabled additional permute "
|
||||
"on local device KV."
|
||||
)
|
||||
self.enable_permute_local_kv = True
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"Heterogeneous TP expects same kv_cache_layout. "
|
||||
"Or enable experimental feature to use HND to NHD support by "
|
||||
"setting 'enable_permute_local_kv'=True in --kv-transfer-config."
|
||||
)
|
||||
if self.use_mla or is_kv_replicated:
|
||||
# With replicated KV cache, only the number of blocks can differ.
|
||||
assert self.block_len_per_layer == nixl_agent_meta.block_lens, (
|
||||
"KV cache sizes must match between P and D when replicated"
|
||||
)
|
||||
remote_block_size = remote_block_len // (self.slot_size_per_layer[0])
|
||||
else:
|
||||
# When MLA is not used, this is a list of the same block length
|
||||
for block_len in nixl_agent_meta.block_lens:
|
||||
assert block_len == remote_block_len, (
|
||||
"All remote layers must have the same block size"
|
||||
)
|
||||
remote_block_size = remote_block_len // (
|
||||
self.slot_size_per_layer[0] * tp_ratio
|
||||
)
|
||||
if self._use_flashinfer:
|
||||
# With flashinfer, KV are sent in the same message.
|
||||
remote_block_size //= 2
|
||||
if tp_ratio > 1:
|
||||
# Heterogeneous TP expects same kv_cache_layout.
|
||||
if nixl_agent_meta.kv_cache_layout == "NHD":
|
||||
raise ValueError(
|
||||
"Heterogeneous TP is not supported for remote with NHD."
|
||||
)
|
||||
if self.device_type == "xpu":
|
||||
raise ValueError("Heterogeneous TP is not supported on XPU")
|
||||
|
||||
assert remote_block_len == self.block_len_per_layer[0] * tp_ratio, (
|
||||
"Remote P worker KV layer cache must be of shape [2, N, "
|
||||
"local_kv_heads*tp_ratio, block_size, head_dim] and same dtype."
|
||||
)
|
||||
|
||||
assert self.block_size == remote_block_size, (
|
||||
"Remote P worker with different page/block size is not supported "
|
||||
f"{self.block_size=}, {remote_block_size=}"
|
||||
)
|
||||
|
||||
# Create dst descs and xfer side handles. TP workers have same #blocks.
|
||||
if engine_id in self.dst_num_blocks:
|
||||
assert self.dst_num_blocks[engine_id] == nixl_agent_meta.num_blocks
|
||||
else:
|
||||
# Create dst descs and xfer side handles. TP workers have same #blocks
|
||||
# so we only register once per engine_id.
|
||||
if engine_id not in self.dst_num_blocks:
|
||||
self.dst_num_blocks[engine_id] = nixl_agent_meta.num_blocks
|
||||
|
||||
# Keep track of remote agent kv caches base addresses.
|
||||
self.kv_caches_base_addr[engine_id] = nixl_agent_meta.kv_caches_base_addr
|
||||
|
||||
self._validate_remote_agent_handshake(nixl_agent_meta, remote_tp_size)
|
||||
|
||||
# Number of D TP workers reading from a single P TP worker. This is
|
||||
# 1 when P and D `--tensor-parallel-size` match.
|
||||
tp_ratio = self.kv_topo.tp_ratio_from_engine_id(engine_id)
|
||||
|
||||
### Register remote agent memory regions
|
||||
blocks_data = []
|
||||
# With homogeneous TP, D pulls the whole kv cache from corresponding
|
||||
# rank. With heterogeneous TP, prepare the descriptors by splitting the
|
||||
# P KV cache along kv_head dim, of D worker's kv_head size (D>P).
|
||||
# Eg. PTP1 DTP2 => P0 KV:[block0-KV_0 | block0-KV_1..].
|
||||
self.kv_caches_base_addr[engine_id] = nixl_agent_meta.kv_caches_base_addr
|
||||
|
||||
assert len(nixl_agent_meta.kv_caches_base_addr) == len(self.block_len_per_layer)
|
||||
# Register all remote blocks, but only the corresponding kv heads.
|
||||
for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr):
|
||||
kv_block_len = self.get_backend_aware_kv_block_len(layer_idx=i)
|
||||
rank_offset = (
|
||||
self.tp_rank % tp_ratio * kv_block_len
|
||||
if not (self.use_mla or is_kv_replicated)
|
||||
else 0
|
||||
self.tp_rank % tp_ratio * kv_block_len if not replicates_kv_cache else 0
|
||||
)
|
||||
for block_id in range(nixl_agent_meta.num_blocks):
|
||||
block_offset = block_id * nixl_agent_meta.block_lens[i]
|
||||
@ -1213,6 +1227,80 @@ class NixlConnectorWorker:
|
||||
|
||||
return remote_agent_name
|
||||
|
||||
def _validate_remote_agent_handshake(
|
||||
self, nixl_agent_meta: NixlAgentMetadata, remote_tp_size: int
|
||||
):
|
||||
"""
|
||||
Validate the remote agent handshake metadata ensuring the
|
||||
invariants hold true.
|
||||
"""
|
||||
remote_engine_id = nixl_agent_meta.engine_id
|
||||
|
||||
assert self._tp_size[remote_engine_id] == remote_tp_size
|
||||
# TODO We may eventually want to skip enforcing the same attn backend.
|
||||
assert nixl_agent_meta.attn_backend_name == self.backend_name
|
||||
|
||||
tp_ratio = self.kv_topo.tp_ratio_from_engine_id(remote_engine_id)
|
||||
assert tp_ratio > 0, "Decode TP cannot be smaller than prefill TP"
|
||||
assert not self._use_pallas or tp_ratio == 1, (
|
||||
"TPU (pallas_v1) DOES NOT support heterogeneous TP yet."
|
||||
)
|
||||
if not self.use_mla and nixl_agent_meta.kv_cache_layout != self.kv_cache_layout:
|
||||
if (
|
||||
self.kv_transfer_config.enable_permute_local_kv
|
||||
and nixl_agent_meta.kv_cache_layout == "HND"
|
||||
):
|
||||
logger.info(
|
||||
"Remote is HND and local is NHD, enabled additional permute "
|
||||
"on local device KV."
|
||||
)
|
||||
self.enable_permute_local_kv = True
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"Heterogeneous TP expects same kv_cache_layout. "
|
||||
"Or enable experimental feature to use HND to NHD support by "
|
||||
"setting 'enable_permute_local_kv'=True in --kv-transfer-config."
|
||||
)
|
||||
|
||||
# Block len can only vary across layers when using MLA.
|
||||
remote_block_len = nixl_agent_meta.block_lens[0]
|
||||
if self.use_mla or self.kv_topo.is_kv_replicated(remote_engine_id):
|
||||
# With replicated KV cache, only the number of blocks can differ.
|
||||
assert self.block_len_per_layer == nixl_agent_meta.block_lens, (
|
||||
"KV cache sizes must match between P and D when replicated"
|
||||
)
|
||||
remote_block_size = remote_block_len // (self.slot_size_per_layer[0])
|
||||
else:
|
||||
if tp_ratio > 1 and self.device_type == "xpu":
|
||||
# XPU uses NHD, hence it does not support splitting on H
|
||||
raise ValueError("Heterogeneous TP is not supported on XPU")
|
||||
# When MLA is not used, this is a list of the same block length
|
||||
for block_len in nixl_agent_meta.block_lens:
|
||||
assert block_len == remote_block_len, (
|
||||
"All remote layers must have the same block size"
|
||||
)
|
||||
remote_block_size = remote_block_len // (
|
||||
self.slot_size_per_layer[0] * tp_ratio
|
||||
)
|
||||
if self._use_flashinfer:
|
||||
# With flashinfer, KV are sent in the same message.
|
||||
remote_block_size //= 2
|
||||
|
||||
assert remote_block_len == self.block_len_per_layer[0] * tp_ratio, (
|
||||
"Remote P worker KV layer cache must be of shape [2, N, "
|
||||
"local_kv_heads*tp_ratio, block_size, head_dim] and same dtype."
|
||||
)
|
||||
|
||||
assert self.block_size == remote_block_size, (
|
||||
"Remote P worker with different page/block size is not supported "
|
||||
f"{self.block_size=}, {remote_block_size=}"
|
||||
)
|
||||
|
||||
# TP workers have same #blocks.
|
||||
assert self.dst_num_blocks[remote_engine_id] == nixl_agent_meta.num_blocks
|
||||
|
||||
assert len(nixl_agent_meta.kv_caches_base_addr) == len(self.block_len_per_layer)
|
||||
|
||||
def sync_recved_kv_to_device(self, req_id: str, meta: ReqMeta):
|
||||
"""copy recved kv from host buffer to device."""
|
||||
assert self.use_host_buffer
|
||||
@ -1505,14 +1593,16 @@ class NixlConnectorWorker:
|
||||
|
||||
# Number of D TP workers that will read from dst P. Propagate tp_ratio
|
||||
# on notification so that dst worker can wait before freeing blocks.
|
||||
tp_ratio = self._tp_size[self.engine_id] // self._tp_size[dst_engine_id]
|
||||
tp_ratio = self.kv_topo.tp_ratio_from_engine_id(dst_engine_id)
|
||||
notif_id = f"{request_id}:{tp_ratio}".encode()
|
||||
|
||||
# Full prefix cache hit: do not need to read remote blocks,
|
||||
# just notify P worker that we have the blocks we need.
|
||||
num_local_blocks = len(local_block_ids)
|
||||
if num_local_blocks == 0:
|
||||
remote_rank = self.tp_rank // tp_ratio
|
||||
remote_rank = self.kv_topo.get_target_remote_rank_from_engine_id(
|
||||
dst_engine_id
|
||||
)
|
||||
agent_name = self._remote_agents[dst_engine_id][remote_rank]
|
||||
try:
|
||||
self.nixl_wrapper.send_notif(agent_name, notif_msg=notif_id)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user