[P/D] Heterogeneous TP (#18833)

Signed-off-by: nicklucche <nlucches@redhat.com>
This commit is contained in:
Nicolò Lucchesi 2025-06-05 01:25:34 +02:00 committed by GitHub
parent 23027e2daf
commit b2fac67130
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 288 additions and 101 deletions

View File

@ -8,7 +8,9 @@ MODELS=(
# Number of prefill and decode instances to create
NUM_PREFILL_INSTANCES=${NUM_PREFILL_INSTANCES:-1} # Default to 1
NUM_DECODE_INSTANCES=${NUM_DECODE_INSTANCES:-2} # Default to 2
NUM_DECODE_INSTANCES=${NUM_DECODE_INSTANCES:-1} # Default to 1
PREFILLER_TP_SIZE=${PREFILLER_TP_SIZE:-1}
DECODER_TP_SIZE=${DECODER_TP_SIZE:-1}
# Find the git repository root directory
GIT_ROOT=$(git rev-parse --show-toplevel)
@ -74,9 +76,10 @@ run_tests_for_model() {
for i in $(seq 0 $((NUM_PREFILL_INSTANCES-1))); do
# Calculate GPU ID - we'll distribute across available GPUs
GPU_ID=$((i % $(get_num_gpus)))
# Calculate port number (base port + instance number)
PORT=$((8100 + i))
# Calculate side channel port
# Calculate side channel port. Avoid clash with TP workers.
SIDE_CHANNEL_PORT=$((5559 + i))
echo "Starting prefill instance $i on GPU $GPU_ID, port $PORT"
@ -87,6 +90,7 @@ run_tests_for_model() {
--enforce-eager \
--disable-log-requests \
--gpu-memory-utilization 0.2 \
--tensor-parallel-size $PREFILLER_TP_SIZE \
--kv-transfer-config '{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\"}'"
if [ -n "$model_args" ]; then
@ -109,7 +113,7 @@ run_tests_for_model() {
# Calculate port number (base port + instance number)
PORT=$((8200 + i))
# Calculate side channel port
SIDE_CHANNEL_PORT=$((5659 + i))
SIDE_CHANNEL_PORT=$((5659 + i * $DECODER_TP_SIZE))
echo "Starting decode instance $i on GPU $GPU_ID, port $PORT"
@ -119,6 +123,7 @@ run_tests_for_model() {
--enforce-eager \
--disable-log-requests \
--gpu-memory-utilization 0.2 \
--tensor-parallel-size $DECODER_TP_SIZE \
--kv-transfer-config '{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\"}'"
if [ -n "$model_args" ]; then

View File

@ -14,6 +14,7 @@ RTOL = 0.03
# Model-specific expected values
EXPECTED_VALUES = {
"Qwen/Qwen3-0.6B": 0.41,
"deepseek-ai/deepseek-vl2-small": 0.59
}
SIMPLE_PROMPT = "The best part about working on vLLM is that I got to meet so many people across various different organizations like UCB, Google, and Meta which means", # noqa: E501

View File

@ -3,11 +3,12 @@
"""
KV cache helper for store.
"""
import torch
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.config import VllmConfig
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.logger import init_logger
logger = init_logger(__name__)
@ -90,3 +91,18 @@ class model_aware_kv_ops_helper:
layer.self_attn.attn._k_scale,
layer.self_attn.attn._v_scale,
)
def get_kv_connector_cache_layout():
    """Return the KV cache memory layout to use: "NHD" or "HND".

    Reads the ambient vLLM config (set via set_current_vllm_config()).
    When a NixlConnector is configured and the model does not use MLA,
    "HND" is returned for better transfer performance; in every other
    case the default "NHD" layout is returned.
    """
    vllm_config = get_current_vllm_config()
    kv_config = vllm_config.kv_transfer_config
    # Bug fix: kv_transfer_config may be None when no KV transfer is
    # configured; the original dereferenced kv_config.kv_connector
    # unconditionally, raising AttributeError in that case.
    if kv_config is None or vllm_config.model_config is None:
        logger.warning("Unable to detect current VLLM config. "
                       "Defaulting to NHD kv cache layout.")
    else:
        use_mla = vllm_config.model_config.use_mla
        if not use_mla and kv_config.kv_connector == "NixlConnector":
            logger.info("NixlConnector detected. Setting KV cache "
                        "layout to HND for better xfer performance.")
            return "HND"
    return "NHD"

View File

@ -32,6 +32,7 @@ if TYPE_CHECKING:
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.request import Request
Transfer = tuple[int, float] # (xfer_handle, start_time)
GET_META_MSG = b"get_meta_msg"
logger = init_logger(__name__)
@ -54,6 +55,8 @@ class NixlAgentMetadata(
agent_metadata: bytes
kv_caches_base_addr: list[int]
num_blocks: int
tp_size: int
block_len: int
@dataclass
@ -331,10 +334,14 @@ class NixlConnectorWorker:
logger.info("Initializing NIXL wrapper")
logger.info("Initializing NIXL worker %s", engine_id)
# Config.
self.vllm_config = vllm_config
self.block_size = vllm_config.cache_config.block_size
# Agent.
self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), None)
# Map of engine_id -> agent_name.
self._remote_agents: dict[str, str] = {}
# Map of engine_id -> {rank0: agent_name0, rank1: agent_name1..}.
self._remote_agents: dict[str, dict[int, str]] = defaultdict(dict)
# NIXL handshake port.
# NOTE(rob): Within a DP group, each DP rank gets its own
@ -354,7 +361,8 @@ class NixlConnectorWorker:
# KV Caches and nixl tracking data.
self.kv_caches: dict[str, torch.Tensor] = {}
# Map of engine_id -> kv_caches_base_addr
# Map of engine_id -> kv_caches_base_addr. For TP case, each local
# rank will still only pull from a single remote TP worker.
self.kv_caches_base_addr: dict[str, list[int]] = {}
# Number of NIXL regions. Currently one region per cache
@ -362,19 +370,19 @@ class NixlConnectorWorker:
self.num_regions = 0
self.num_layers = 0
# nixl_prepped_dlist_handle (int).
# nixl_prepped_dlist_handle.
self.src_xfer_side_handle: int = 0
# Map of engine_id -> nixl_prepped_dlist_handle (int)].
self.dst_xfer_side_handles: dict[str, int] = {}
# Map of engine_id -> num_blocks.
# Map of engine_id -> num_blocks. All ranks in the same deployment will
# have the same number of blocks.
self.dst_num_blocks: dict[str, int] = {}
self._registered_descs: list[Any] = []
# In progress transfers.
# [req_id -> list[handle]]
self._recving_transfers: defaultdict[str, list[Any]] = defaultdict(
list[Any])
self._recving_transfers = defaultdict[str, list[Transfer]](list)
# Complete transfer tracker. Used by the rank 0 to track finished
# transactions on ranks 1 to N-1.
@ -395,6 +403,11 @@ class NixlConnectorWorker:
# List of block window sizes for each layer for local attention
self.block_window_per_layer: list[Optional[int]] = []
self._tp_size: dict[str, int] = {self.engine_id: self.world_size}
# With heterogeneous TP, P must wait for all assigned D TP workers to
# finish reading before safely freeing the blocks.
self.consumer_notification_counts_by_req = defaultdict[str, int](int)
@staticmethod
def _nixl_handshake_listener(metadata: NixlAgentMetadata,
ready_event: threading.Event, base_port: int,
@ -426,27 +439,44 @@ class NixlConnectorWorker:
"""Do a NIXL handshake with a remote instance."""
start_time = time.perf_counter()
# NOTE(rob): we need each tp_rank to have a unique port.
# This is a hack to keep us moving. We will switch when
# we switch to HTTP-based NIXL metadata exchange.
path = make_zmq_path("tcp", host, port + self.tp_rank)
logger.debug("Querying metadata on path: %s", path)
with zmq_ctx(zmq.REQ, path) as sock:
# NOTE(rob): we need each rank to have a unique port. This is
# a hack to keep us moving. We will switch when moving to etcd
# or where we have a single ZMQ socket in the scheduler.
def handshake(path: str, rank: int) -> NixlAgentMetadata:
# Send query for the request.
sock.send(GET_META_MSG)
metadata_bytes = sock.recv()
decoder = msgspec.msgpack.Decoder(NixlAgentMetadata)
metadata = decoder.decode(metadata_bytes)
got_metadata_time = time.perf_counter()
with zmq_ctx(zmq.REQ, path) as sock:
sock.send(GET_META_MSG)
metadata_bytes = sock.recv()
decoder = msgspec.msgpack.Decoder(NixlAgentMetadata)
metadata = decoder.decode(metadata_bytes)
got_metadata_time = time.perf_counter()
# Register Remote agent.
self.add_remote_agent(metadata)
setup_agent_time = time.perf_counter()
# Register Remote agent.
self.add_remote_agent(metadata, rank)
setup_agent_time = time.perf_counter()
logger.debug("NIXL handshake: get metadata took: %s",
got_metadata_time - start_time)
logger.debug("NIXL handshake: add agent took: %s",
setup_agent_time - got_metadata_time)
logger.debug("NIXL handshake: get metadata took: %s",
got_metadata_time - start_time)
logger.debug("NIXL handshake: add agent took: %s",
setup_agent_time - got_metadata_time)
return metadata
# Handshake with remote agent-rank0 first to get the tp_size of remote
path = make_zmq_path("tcp", host, port)
logger.debug("Querying master rank metadata on path: %s", path)
metadata = handshake(path, 0)
# Handshake only with the other TP remote the current local rank will
# pull from. With homogeneous TP it happens to be the same rank_i.
tp_ratio = self._tp_size[self.engine_id] // metadata.tp_size
p_remote_rank = self.tp_rank // tp_ratio
if p_remote_rank > 0:
path = make_zmq_path("tcp", host, port + p_remote_rank)
logger.debug("Querying metadata on path: %s at remote rank %s",
path, p_remote_rank)
_ = handshake(path, p_remote_rank)
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
"""Register the KV Cache data in nixl."""
@ -455,24 +485,34 @@ class NixlConnectorWorker:
kv_elem_size = first_kv_cache.element_size()
# TODO(tms): Find a more robust way to detect and handle MLA
use_mla = len(first_kv_cache.shape) == 3
if use_mla:
self.use_mla = len(first_kv_cache.shape) == 3
# NOTE (NickLucche) To move blocks efficiently with NIXL, the expected
# KV memory layout is HND, as opposed to the default NHD. Note that it
will only affect the strides. For MLA instead, we require no
such thing and resort to the standard layout.
if self.use_mla:
# MLA case.
self.num_blocks = first_kv_cache.shape[0]
block_rank = 2 # [block_size, latent_dim]
block_shape = first_kv_cache.shape[-block_rank:]
block_size, kv_latent_dim = block_shape
self.slot_size_bytes = kv_elem_size * kv_latent_dim
else:
# [2 (k and v), num_blocks, ...]
# [2 (k and v), num_blocks, block_size, kv_heads, head_dim]
self.num_blocks = first_kv_cache.shape[1]
block_rank = 3 # [block_size, kv_heads, head_dim]
block_shape = first_kv_cache.shape[-block_rank:]
block_size, n_kv_heads, head_dim = block_shape
# head size in bytes.
self.slot_size_bytes = kv_elem_size * n_kv_heads * head_dim
assert block_size == self.block_size
# TODO(tms): self.block_len needs to be per-layer for sliding window,
# hybrid attn, etc
# block size in bytes
self.block_len = kv_elem_size * math.prod(block_shape)
logger.debug("Registering KV_Caches. use_mla: %s, shape %s", use_mla,
first_kv_cache.shape)
logger.debug("Registering KV_Caches. use_mla: %s, shape %s",
self.use_mla, first_kv_cache.shape)
logger.debug("num_blocks: %s, block_shape: %s", self.num_blocks,
block_shape)
logger.debug("Per layer kv cache size: %s", first_kv_cache.shape)
@ -489,7 +529,7 @@ class NixlConnectorWorker:
# (roughly 8KB vs 5KB).
for cache_or_caches in kv_caches.values():
# Normalize to always be a list of caches
cache_list = [cache_or_caches] if use_mla else cache_or_caches
cache_list = [cache_or_caches] if self.use_mla else cache_or_caches
for cache in cache_list:
base_addr = cache.data_ptr()
region_len = self.num_blocks * self.block_len
@ -524,16 +564,37 @@ class NixlConnectorWorker:
logger.debug("Registering descs: %s", caches_data)
self.nixl_wrapper.register_memory(descs)
logger.debug("Done registering descs")
self._registered_descs.append(descs)
# Register local/src descr for NIXL xfer.
blocks_data = []
for base_addr in self.kv_caches_base_addr[self.engine_id]:
# NOTE With heter-TP, more blocks are prepared than what are
# needed as self.num_blocks >= nixl_agent_meta.num_blocks. We
# could create fewer, but then _get_block_descs_ids needs to
# select agent_meta.num_blocks instead of self.num_blocks for
# local descr, and that makes handling regular flow less clean.
for block_id in range(self.num_blocks):
block_offset = block_id * self.block_len
addr = base_addr + block_offset
# (addr, len, device id)
blocks_data.append((addr, self.block_len, self.tp_rank))
logger.debug("Created %s blocks for src engine %s and rank %s",
len(blocks_data), self.engine_id, self.tp_rank)
descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
# NIXL_INIT_AGENT to be used for preparations of local descs.
self.src_xfer_side_handle = self.nixl_wrapper.prep_xfer_dlist(
"NIXL_INIT_AGENT", descs)
# After KV Caches registered, listen for new connections.
metadata = NixlAgentMetadata(
engine_id=self.engine_id,
agent_metadata=self.nixl_wrapper.get_agent_metadata(),
kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id],
num_blocks=self.num_blocks,
)
tp_size=self.world_size,
block_len=self.block_len)
ready_event = threading.Event()
self._nixl_handshake_listener_t = threading.Thread(
target=self._nixl_handshake_listener,
@ -543,50 +604,123 @@ class NixlConnectorWorker:
self._nixl_handshake_listener_t.start()
ready_event.wait()
def add_remote_agent(self, nixl_agent_meta: NixlAgentMetadata):
def add_remote_agent(self,
nixl_agent_meta: NixlAgentMetadata,
remote_tp_rank: int = 0):
"""
Add the remote NIXL agent and prepare the descriptors for reading cache
blocks from remote.
In particular, handle both homogeneous and heterogeneous TP. The former
requires local rank_i to read from remote rank_i.
The latter, assuming D.world_size > P.world_size, requires that two or
more local TP worker share the xfer from a single TP worker.
Here's an example:
rank_offset p_remote_tp_rank
(kv split no)
--------------------------------
0 0 Worker0 ---- 1st half of KV ----> Worker0 [ KV Cache ]
/
1 0 Worker1 ---- 2nd half of KV -----/
0 1 Worker2 ---- 1st half of KV ----> Worker1 [ KV Cache ]
/
1 1 Worker3 ---- 2nd half of KV -----/
Decoder TP workers Prefill TP workers
(world_size=4) (world_size=2)
tp_ratio = 4 // 2 = 2
Considering the KV Caches, if P-Worker_i has cache size [2, num_blocksP, kv_heads, block_size, head_dim]
then D-Worker_j has [2, num_blocksD, kv_heads//tp_ratio, block_size, head_dim]. Mind the "HND" layout format.
Assuming num_blocksD >= num_blocksP, D-Worker0 reads from P-Worker0 by preparing the kv_heads//tp_ratio
first heads from all the slots of all the blocks. D-Worker1 will do the same, but reading the second split
along the kv_heads dimension, and so forth until "tp_ratio" D TP workers have pulled from P-Worker0.
Note that the above will also hold true for the homogeneous TP case, where tp_ratio evaluates to 1.
Regarding MLA case, the cache is replicated across TP workers so the rank_offset will just always be 0
so that the whole cache is shared by "tp_ratio" D TP workers.
""" # noqa: E501
engine_id = nixl_agent_meta.engine_id
assert engine_id != self.engine_id, "Conflict engine id found!"
if engine_id in self._remote_agents:
# TODO re-evaluate refreshing for scaling/recovery
if remote_tp_rank in self._remote_agents.get(engine_id, ()):
return
self._remote_agents[engine_id] = self.nixl_wrapper.add_remote_agent(
nixl_agent_meta.agent_metadata)
self.kv_caches_base_addr[
engine_id] = nixl_agent_meta.kv_caches_base_addr
if engine_id in self._tp_size:
assert self._tp_size[engine_id] == nixl_agent_meta.tp_size
else:
self._tp_size[engine_id] = nixl_agent_meta.tp_size
self._remote_agents[engine_id][
remote_tp_rank] = self.nixl_wrapper.add_remote_agent(
nixl_agent_meta.agent_metadata)
# Number of D TP workers reading from a single P TP worker. This is
# 1 when P and D `--tensor-parallel-size` match.
assert self._tp_size[self.engine_id] % self._tp_size[engine_id] == 0, \
"Local TP size must be divisible by remote TP size."
tp_ratio = self._tp_size[self.engine_id] // self._tp_size[engine_id]
assert tp_ratio > 0, "Decode TP cannot be smaller than"
" prefill TP"
if self.use_mla:
# With MLA the only difference is in the number of blocks.
remote_block_size = nixl_agent_meta.block_len / (
self.slot_size_bytes)
assert self.block_len == nixl_agent_meta.block_len
else:
remote_block_size = nixl_agent_meta.block_len / (
self.slot_size_bytes * tp_ratio)
assert nixl_agent_meta.block_len == self.block_len * tp_ratio, \
"Remote P worker KV layer cache must be of shape [2, N, \
local_kv_heads*tp_ratio, block_size, head_dim] and same dtype."
assert self.block_size == remote_block_size, "Remote P worker with \
different block size is not supported"
assert self.num_blocks >= nixl_agent_meta.num_blocks
# Create dst descs and xfer side handles. TP workers have same #blocks.
if engine_id in self.dst_num_blocks:
assert self.dst_num_blocks[engine_id] == nixl_agent_meta.num_blocks
else:
self.dst_num_blocks[engine_id] = nixl_agent_meta.num_blocks
# Create src descs and xfer side handles.
blocks_data = []
for base_addr in self.kv_caches_base_addr[self.engine_id]:
for block_id in range(self.num_blocks):
block_offset = block_id * self.block_len
# (addr, len, device id)
blocks_data.append(
(base_addr + block_offset, self.block_len, self.tp_rank))
logger.debug("Created %s blocks for src engine %s and tp_rank %s",
len(blocks_data), self.engine_id, self.tp_rank)
# With homogeneous TP, D pulls the whole kv cache from corresponding
# rank. With heterogeneous TP, prepare the descriptors by splitting the
# P KV cache along kv_head dim, of D worker's kv_head size (D>P).
# Eg. PTP1 DTP2 => P0 KV:[block0-KV_0 | block0-KV_1..].
p_remote_tp_rank = self.tp_rank // tp_ratio
# Only register the remote's descriptors if current rank pulls from it.
if p_remote_tp_rank == remote_tp_rank:
self.kv_caches_base_addr[
engine_id] = nixl_agent_meta.kv_caches_base_addr
rank_offset = self.tp_rank % tp_ratio * self.block_len \
if not self.use_mla else 0
# Register all remote blocks, but only the corresponding kv heads.
for base_addr in nixl_agent_meta.kv_caches_base_addr:
for block_id in range(nixl_agent_meta.num_blocks):
block_offset = block_id * nixl_agent_meta.block_len
# For each block, grab the heads chunk belonging to rank_i
# of size remote_nheads // tp_ratio, which correspond to
# self.block_len == remote_block_len//tp_ratio bytes.
addr = base_addr + block_offset + rank_offset
# (addr, len, device id)
blocks_data.append((addr, self.block_len, remote_tp_rank))
logger.debug(
"Created %s blocks for dst engine %s with remote rank %s and " \
"local rank %s",
len(blocks_data), engine_id, remote_tp_rank, self.tp_rank)
# Register with NIXL.
descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
self.src_xfer_side_handle = self.nixl_wrapper.prep_xfer_dlist(
"NIXL_INIT_AGENT", descs)
# Create dst descs and xfer side handles.
self.dst_num_blocks[engine_id] = nixl_agent_meta.num_blocks
blocks_data = []
for base_addr in self.kv_caches_base_addr[engine_id]:
for block_id in range(nixl_agent_meta.num_blocks):
block_offset = block_id * self.block_len
# (addr, len, device id)
blocks_data.append(
(base_addr + block_offset, self.block_len, self.tp_rank))
logger.debug("Created %s blocks for dst engine %s and tp_rank %s",
len(blocks_data), engine_id, self.tp_rank)
# Register with NIXL.
descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
self.dst_xfer_side_handles[
engine_id] = self.nixl_wrapper.prep_xfer_dlist(
self._remote_agents[engine_id], descs)
# Register with NIXL.
descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
self.dst_xfer_side_handles[
engine_id] = self.nixl_wrapper.prep_xfer_dlist(
self._remote_agents[engine_id][remote_tp_rank], descs)
def get_finished(self) -> tuple[set[str], set[str]]:
"""
@ -654,16 +788,25 @@ class NixlConnectorWorker:
return done_sending, done_recving
def _get_new_notifs(self) -> set[str]:
"""Get req_ids which got a remote xfer message."""
"""
Get req_ids which got a remote xfer message. When multiple consumers
are reading from the same producer (heterogeneous TP scenario), wait
for all consumers to be done pulling.
"""
notified_req_ids: set[str] = set()
for req_ids in self.nixl_wrapper.get_new_notifs().values():
for req_id in req_ids:
assert req_id not in notified_req_ids
notified_req_ids.add(req_id.decode("utf-8"))
for notifs in self.nixl_wrapper.get_new_notifs().values():
for notif in notifs:
req_id, tp_ratio = notif.decode("utf-8").rsplit(":", 1)
self.consumer_notification_counts_by_req[req_id] += 1
# Wait all consumers (D) to be done reading before freeing.
if self.consumer_notification_counts_by_req[req_id] == int(
tp_ratio):
notified_req_ids.add(req_id)
del self.consumer_notification_counts_by_req[req_id]
return notified_req_ids
def _pop_done_transfers(self, transfers: dict[str, list[int]]) -> set[str]:
def _pop_done_transfers(
self, transfers: dict[str, list[tuple[int, float]]]) -> set[str]:
"""
Pop completed xfers by checking for DONE state.
Args:
@ -673,23 +816,17 @@ class NixlConnectorWorker:
"""
done_req_ids: set[str] = set()
for req_id, handles in list(transfers.items()):
running_reqs = []
for handle in handles:
for handle, xfer_stime in handles:
xfer_state = self.nixl_wrapper.check_xfer_state(handle)
if xfer_state == "DONE":
# TODO ptarasiewicz: why abort is throwing errors?
# self.nixl_wrapper.release_xfer_handle(handle)
self.nixl_wrapper.release_xfer_handle(handle)
done_req_ids.add(req_id)
del transfers[req_id]
elif xfer_state == "PROC":
continue
if xfer_state == "PROC":
running_reqs.append(handle)
else:
raise RuntimeError("Transfer failed with state %s",
xfer_state)
if len(running_reqs) == 0:
done_req_ids.add(req_id)
del transfers[req_id]
else:
transfers[req_id] = running_reqs
return done_req_ids
def start_load_kv(self, metadata: NixlConnectorMetadata):
@ -735,13 +872,19 @@ class NixlConnectorWorker:
# saturate IB with heterogeneous TP sizes. We should remove the staging
# blocks until we are ready.
# Number of D TP workers that will read from dst P. Propagate tp_ratio
# on notification so that dst worker can wait before freeing blocks.
tp_ratio = self._tp_size[
self.engine_id] // self._tp_size[dst_engine_id]
notif_id = f"{request_id}:{tp_ratio}".encode()
# Full prefix cache hit: do not need to read remote blocks,
# just notify P worker that we have the blocks we need.
num_local_blocks = len(local_block_ids)
if num_local_blocks == 0:
agent_name = self._remote_agents[dst_engine_id]
self.nixl_wrapper.send_notif(agent_name,
notif_msg=request_id.encode("utf-8"))
remote_rank = self.tp_rank // tp_ratio
agent_name = self._remote_agents[dst_engine_id][remote_rank]
self.nixl_wrapper.send_notif(agent_name, notif_msg=notif_id)
return
# Partial prefix cache hit: just read uncomputed blocks.
@ -754,6 +897,10 @@ class NixlConnectorWorker:
local_xfer_side_handle = self.src_xfer_side_handle
remote_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id]
# NOTE (nicolo) With homogeneous TP, each TP worker loads KV from
# corresponding rank. With heterogeneous TP, fixing D>P, the D tp
# workers will issue xfers to parts of the P worker remote kv caches.
# Get descs ids.
local_block_descs_ids: list[int] = []
remote_block_descs_ids: list[int] = []
@ -797,14 +944,16 @@ class NixlConnectorWorker:
local_block_descs_ids,
remote_xfer_side_handle,
remote_block_descs_ids,
notif_msg=request_id.encode("utf-8"),
notif_msg=notif_id,
)
# Begin async xfer.
self.nixl_wrapper.transfer(handle)
# Use handle to check completion in future step().
self._recving_transfers[request_id].append(handle)
# TODO (NickLucche) surface xfer elapsed time
self._recving_transfers[request_id].append(
(handle, time.perf_counter()))
def _get_block_descs_ids(self,
engine_id: str,
@ -815,7 +964,6 @@ class NixlConnectorWorker:
If layer_idx is provided, we use the region_ids for the given layer.
Otherwise, we use all regions.
"""
if layer_idx is None:
region_ids = range(self.num_regions)
else:

View File

@ -16,6 +16,8 @@ from vllm.attention.ops.merge_attn_states import merge_attn_states
from vllm.attention.utils.fa_utils import (flash_attn_supports_fp8,
get_flash_attn_version)
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.distributed.kv_transfer.kv_connector.utils import (
get_kv_connector_cache_layout)
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import cdiv
@ -70,6 +72,20 @@ class FlashAttentionBackend(AttentionBackend):
raise ValueError("Block size must be a multiple of 16.")
return (2, num_blocks, block_size, num_kv_heads, head_size)
@staticmethod
def get_kv_cache_stride_order() -> tuple[int, ...]:
    """Return the dim permutation from get_kv_cache_shape to memory layout.

    When running disaggregated P/D with NIXL, the HND layout is used for
    faster transfer. `stride_order` indicates the permutation that gets
    us from `get_kv_cache_shape` (2, num_blocks, block_size, num_kv_heads,
    head_size) to the actual memory layout we want.
    """
    cache_layout = get_kv_connector_cache_layout()
    if cache_layout == "NHD":
        # Identity permutation: keep the logical shape as-is.
        stride_order = (0, 1, 2, 3, 4)
    elif cache_layout == "HND":
        # Swap the block_size and num_kv_heads dimensions.
        stride_order = (0, 1, 3, 2, 4)
    else:
        # Bug fix: ValueError does not %-format its arguments the way a
        # logger call does, so the original raised with an unformatted
        # message plus a stray tuple element. Build the message explicitly.
        raise ValueError(f"Unknown cache layout format {cache_layout}.")
    return stride_order
@dataclass
class FlashAttentionMetadata:

View File

@ -597,7 +597,8 @@ class WorkerWrapperBase:
def initialize_from_config(self, kv_cache_configs: List[Any]) -> None:
kv_cache_config = kv_cache_configs[self.rpc_rank]
self.worker.initialize_from_config(kv_cache_config) # type: ignore
with set_current_vllm_config(self.vllm_config):
self.worker.initialize_from_config(kv_cache_config) # type: ignore
def init_device(self):
with set_current_vllm_config(self.vllm_config):