mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-09 18:57:08 +08:00
[P/D] Heterogeneous TP (#18833)
Signed-off-by: nicklucche <nlucches@redhat.com>
This commit is contained in:
parent
23027e2daf
commit
b2fac67130
@ -8,7 +8,9 @@ MODELS=(
|
||||
|
||||
# Number of prefill and decode instances to create
|
||||
NUM_PREFILL_INSTANCES=${NUM_PREFILL_INSTANCES:-1} # Default to 1
|
||||
NUM_DECODE_INSTANCES=${NUM_DECODE_INSTANCES:-2} # Default to 2
|
||||
NUM_DECODE_INSTANCES=${NUM_DECODE_INSTANCES:-1} # Default to 1
|
||||
PREFILLER_TP_SIZE=${PREFILLER_TP_SIZE:-1}
|
||||
DECODER_TP_SIZE=${DECODER_TP_SIZE:-1}
|
||||
|
||||
# Find the git repository root directory
|
||||
GIT_ROOT=$(git rev-parse --show-toplevel)
|
||||
@ -74,9 +76,10 @@ run_tests_for_model() {
|
||||
for i in $(seq 0 $((NUM_PREFILL_INSTANCES-1))); do
|
||||
# Calculate GPU ID - we'll distribute across available GPUs
|
||||
GPU_ID=$((i % $(get_num_gpus)))
|
||||
|
||||
# Calculate port number (base port + instance number)
|
||||
PORT=$((8100 + i))
|
||||
# Calculate side channel port
|
||||
# Calculate side channel port. Avoid clash with TP workers.
|
||||
SIDE_CHANNEL_PORT=$((5559 + i))
|
||||
|
||||
echo "Starting prefill instance $i on GPU $GPU_ID, port $PORT"
|
||||
@ -87,6 +90,7 @@ run_tests_for_model() {
|
||||
--enforce-eager \
|
||||
--disable-log-requests \
|
||||
--gpu-memory-utilization 0.2 \
|
||||
--tensor-parallel-size $PREFILLER_TP_SIZE \
|
||||
--kv-transfer-config '{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\"}'"
|
||||
|
||||
if [ -n "$model_args" ]; then
|
||||
@ -109,7 +113,7 @@ run_tests_for_model() {
|
||||
# Calculate port number (base port + instance number)
|
||||
PORT=$((8200 + i))
|
||||
# Calculate side channel port
|
||||
SIDE_CHANNEL_PORT=$((5659 + i))
|
||||
SIDE_CHANNEL_PORT=$((5659 + i * $DECODER_TP_SIZE))
|
||||
|
||||
echo "Starting decode instance $i on GPU $GPU_ID, port $PORT"
|
||||
|
||||
@ -119,6 +123,7 @@ run_tests_for_model() {
|
||||
--enforce-eager \
|
||||
--disable-log-requests \
|
||||
--gpu-memory-utilization 0.2 \
|
||||
--tensor-parallel-size $DECODER_TP_SIZE \
|
||||
--kv-transfer-config '{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\"}'"
|
||||
|
||||
if [ -n "$model_args" ]; then
|
||||
|
||||
@ -14,6 +14,7 @@ RTOL = 0.03
|
||||
# Model-specific expected values
|
||||
EXPECTED_VALUES = {
|
||||
"Qwen/Qwen3-0.6B": 0.41,
|
||||
"deepseek-ai/deepseek-vl2-small": 0.59
|
||||
}
|
||||
|
||||
SIMPLE_PROMPT = "The best part about working on vLLM is that I got to meet so many people across various different organizations like UCB, Google, and Meta which means", # noqa: E501
|
||||
|
||||
@ -3,11 +3,12 @@
|
||||
"""
|
||||
KV cache helper for store.
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config import VllmConfig, get_current_vllm_config
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -90,3 +91,18 @@ class model_aware_kv_ops_helper:
|
||||
layer.self_attn.attn._k_scale,
|
||||
layer.self_attn.attn._v_scale,
|
||||
)
|
||||
|
||||
|
||||
def get_kv_connector_cache_layout():
|
||||
vllm_config = get_current_vllm_config()
|
||||
kv_config = vllm_config.kv_transfer_config
|
||||
if vllm_config.model_config is None:
|
||||
logger.warning("Unable to detect current VLLM config. " \
|
||||
"Defaulting to NHD kv cache layout.")
|
||||
else:
|
||||
use_mla = vllm_config.model_config.use_mla
|
||||
if not use_mla and kv_config.kv_connector == "NixlConnector":
|
||||
logger.info("NixlConnector detected. Setting KV cache " \
|
||||
"layout to HND for better xfer performance.")
|
||||
return "HND"
|
||||
return "NHD"
|
||||
|
||||
@ -32,6 +32,7 @@ if TYPE_CHECKING:
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||
from vllm.v1.request import Request
|
||||
|
||||
Transfer = tuple[int, float] # (xfer_handle, start_time)
|
||||
GET_META_MSG = b"get_meta_msg"
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -54,6 +55,8 @@ class NixlAgentMetadata(
|
||||
agent_metadata: bytes
|
||||
kv_caches_base_addr: list[int]
|
||||
num_blocks: int
|
||||
tp_size: int
|
||||
block_len: int
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -331,10 +334,14 @@ class NixlConnectorWorker:
|
||||
logger.info("Initializing NIXL wrapper")
|
||||
logger.info("Initializing NIXL worker %s", engine_id)
|
||||
|
||||
# Config.
|
||||
self.vllm_config = vllm_config
|
||||
self.block_size = vllm_config.cache_config.block_size
|
||||
|
||||
# Agent.
|
||||
self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), None)
|
||||
# Map of engine_id -> agent_name.
|
||||
self._remote_agents: dict[str, str] = {}
|
||||
# Map of engine_id -> {rank0: agent_name0, rank1: agent_name1..}.
|
||||
self._remote_agents: dict[str, dict[int, str]] = defaultdict(dict)
|
||||
|
||||
# NIXL handshake port.
|
||||
# NOTE(rob): Within a DP group, each DP rank gets its own
|
||||
@ -354,7 +361,8 @@ class NixlConnectorWorker:
|
||||
# KV Caches and nixl tracking data.
|
||||
self.kv_caches: dict[str, torch.Tensor] = {}
|
||||
|
||||
# Map of engine_id -> kv_caches_base_addr
|
||||
# Map of engine_id -> kv_caches_base_addr. For TP case, each local
|
||||
# rank will still only pull from a single remote TP worker.
|
||||
self.kv_caches_base_addr: dict[str, list[int]] = {}
|
||||
|
||||
# Number of NIXL regions. Currently one region per cache
|
||||
@ -362,19 +370,19 @@ class NixlConnectorWorker:
|
||||
self.num_regions = 0
|
||||
self.num_layers = 0
|
||||
|
||||
# nixl_prepped_dlist_handle (int).
|
||||
# nixl_prepped_dlist_handle.
|
||||
self.src_xfer_side_handle: int = 0
|
||||
# Map of engine_id -> nixl_prepped_dlist_handle (int).
|
||||
self.dst_xfer_side_handles: dict[str, int] = {}
|
||||
|
||||
# Map of engine_id -> num_blocks.
|
||||
# Map of engine_id -> num_blocks. All ranks in the same deployment will
|
||||
# have the same number of blocks.
|
||||
self.dst_num_blocks: dict[str, int] = {}
|
||||
self._registered_descs: list[Any] = []
|
||||
|
||||
# In progress transfers.
|
||||
# [req_id -> list[handle]]
|
||||
self._recving_transfers: defaultdict[str, list[Any]] = defaultdict(
|
||||
list[Any])
|
||||
self._recving_transfers = defaultdict[str, list[Transfer]](list)
|
||||
|
||||
# Complete transfer tracker. Used by the rank 0 to track finished
|
||||
# transactions on ranks 1 to N-1.
|
||||
@ -395,6 +403,11 @@ class NixlConnectorWorker:
|
||||
# List of block window sizes for each layer for local attention
|
||||
self.block_window_per_layer: list[Optional[int]] = []
|
||||
|
||||
self._tp_size: dict[str, int] = {self.engine_id: self.world_size}
|
||||
# With heterogeneous TP, P must wait for all assigned D TP workers to
|
||||
# finish reading before safely freeing the blocks.
|
||||
self.consumer_notification_counts_by_req = defaultdict[str, int](int)
|
||||
|
||||
@staticmethod
|
||||
def _nixl_handshake_listener(metadata: NixlAgentMetadata,
|
||||
ready_event: threading.Event, base_port: int,
|
||||
@ -426,27 +439,44 @@ class NixlConnectorWorker:
|
||||
"""Do a NIXL handshake with a remote instance."""
|
||||
|
||||
start_time = time.perf_counter()
|
||||
# NOTE(rob): we need each tp_rank to have a unique port.
|
||||
# This is a hack to keep us moving. We will switch when
|
||||
# we switch to HTTP-based NIXL metadata exchange.
|
||||
path = make_zmq_path("tcp", host, port + self.tp_rank)
|
||||
logger.debug("Querying metadata on path: %s", path)
|
||||
with zmq_ctx(zmq.REQ, path) as sock:
|
||||
|
||||
# NOTE(rob): we need each rank to have a unique port. This is
|
||||
# a hack to keep us moving. We will switch when moving to etcd
|
||||
# or where we have a single ZMQ socket in the scheduler.
|
||||
|
||||
def handshake(path: str, rank: int) -> NixlAgentMetadata:
|
||||
# Send query for the request.
|
||||
sock.send(GET_META_MSG)
|
||||
metadata_bytes = sock.recv()
|
||||
decoder = msgspec.msgpack.Decoder(NixlAgentMetadata)
|
||||
metadata = decoder.decode(metadata_bytes)
|
||||
got_metadata_time = time.perf_counter()
|
||||
with zmq_ctx(zmq.REQ, path) as sock:
|
||||
sock.send(GET_META_MSG)
|
||||
metadata_bytes = sock.recv()
|
||||
decoder = msgspec.msgpack.Decoder(NixlAgentMetadata)
|
||||
metadata = decoder.decode(metadata_bytes)
|
||||
got_metadata_time = time.perf_counter()
|
||||
|
||||
# Register Remote agent.
|
||||
self.add_remote_agent(metadata)
|
||||
setup_agent_time = time.perf_counter()
|
||||
# Register Remote agent.
|
||||
self.add_remote_agent(metadata, rank)
|
||||
setup_agent_time = time.perf_counter()
|
||||
|
||||
logger.debug("NIXL handshake: get metadata took: %s",
|
||||
got_metadata_time - start_time)
|
||||
logger.debug("NIXL handshake: add agent took: %s",
|
||||
setup_agent_time - got_metadata_time)
|
||||
logger.debug("NIXL handshake: get metadata took: %s",
|
||||
got_metadata_time - start_time)
|
||||
logger.debug("NIXL handshake: add agent took: %s",
|
||||
setup_agent_time - got_metadata_time)
|
||||
return metadata
|
||||
|
||||
# Handshake with remote agent-rank0 first to get the tp_size of remote
|
||||
path = make_zmq_path("tcp", host, port)
|
||||
logger.debug("Querying master rank metadata on path: %s", path)
|
||||
metadata = handshake(path, 0)
|
||||
|
||||
# Handshake only with the remote TP rank that the current local rank will
|
||||
# pull from. With homogeneous TP it happens to be the same rank_i.
|
||||
tp_ratio = self._tp_size[self.engine_id] // metadata.tp_size
|
||||
p_remote_rank = self.tp_rank // tp_ratio
|
||||
if p_remote_rank > 0:
|
||||
path = make_zmq_path("tcp", host, port + p_remote_rank)
|
||||
logger.debug("Querying metadata on path: %s at remote rank %s",
|
||||
path, p_remote_rank)
|
||||
_ = handshake(path, p_remote_rank)
|
||||
|
||||
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
|
||||
"""Register the KV Cache data in nixl."""
|
||||
@ -455,24 +485,34 @@ class NixlConnectorWorker:
|
||||
kv_elem_size = first_kv_cache.element_size()
|
||||
|
||||
# TODO(tms): Find a more robust way to detect and handle MLA
|
||||
use_mla = len(first_kv_cache.shape) == 3
|
||||
if use_mla:
|
||||
self.use_mla = len(first_kv_cache.shape) == 3
|
||||
# NOTE (NickLucche) To move blocks efficiently with NIXL, the expected
|
||||
# KV memory layout is HND, as opposed to the default NHD. Note that it
|
||||
# will only affect the strides. For MLA instead, we require no
|
||||
# such thing and resort to the standard layout.
|
||||
if self.use_mla:
|
||||
# MLA case.
|
||||
self.num_blocks = first_kv_cache.shape[0]
|
||||
block_rank = 2 # [block_size, latent_dim]
|
||||
block_shape = first_kv_cache.shape[-block_rank:]
|
||||
block_size, kv_latent_dim = block_shape
|
||||
self.slot_size_bytes = kv_elem_size * kv_latent_dim
|
||||
else:
|
||||
# [2 (k and v), num_blocks, ...]
|
||||
# [2 (k and v), num_blocks, block_size, kv_heads, head_dim]
|
||||
self.num_blocks = first_kv_cache.shape[1]
|
||||
block_rank = 3 # [block_size, kv_heads, head_dim]
|
||||
block_shape = first_kv_cache.shape[-block_rank:]
|
||||
|
||||
block_size, n_kv_heads, head_dim = block_shape
|
||||
# head size in bytes.
|
||||
self.slot_size_bytes = kv_elem_size * n_kv_heads * head_dim
|
||||
assert block_size == self.block_size
|
||||
# TODO(tms): self.block_len needs to be per-layer for sliding window,
|
||||
# hybrid attn, etc
|
||||
# block size in bytes
|
||||
self.block_len = kv_elem_size * math.prod(block_shape)
|
||||
|
||||
logger.debug("Registering KV_Caches. use_mla: %s, shape %s", use_mla,
|
||||
first_kv_cache.shape)
|
||||
logger.debug("Registering KV_Caches. use_mla: %s, shape %s",
|
||||
self.use_mla, first_kv_cache.shape)
|
||||
logger.debug("num_blocks: %s, block_shape: %s", self.num_blocks,
|
||||
block_shape)
|
||||
logger.debug("Per layer kv cache size: %s", first_kv_cache.shape)
|
||||
@ -489,7 +529,7 @@ class NixlConnectorWorker:
|
||||
# (roughly 8KB vs 5KB).
|
||||
for cache_or_caches in kv_caches.values():
|
||||
# Normalize to always be a list of caches
|
||||
cache_list = [cache_or_caches] if use_mla else cache_or_caches
|
||||
cache_list = [cache_or_caches] if self.use_mla else cache_or_caches
|
||||
for cache in cache_list:
|
||||
base_addr = cache.data_ptr()
|
||||
region_len = self.num_blocks * self.block_len
|
||||
@ -524,16 +564,37 @@ class NixlConnectorWorker:
|
||||
logger.debug("Registering descs: %s", caches_data)
|
||||
self.nixl_wrapper.register_memory(descs)
|
||||
logger.debug("Done registering descs")
|
||||
|
||||
self._registered_descs.append(descs)
|
||||
|
||||
# Register local/src descr for NIXL xfer.
|
||||
blocks_data = []
|
||||
for base_addr in self.kv_caches_base_addr[self.engine_id]:
|
||||
# NOTE With heter-TP, more blocks are prepared than what are
|
||||
# needed as self.num_blocks >= nixl_agent_meta.num_blocks. We
|
||||
# could create fewer, but then _get_block_descs_ids needs to
|
||||
# select agent_meta.num_blocks instead of self.num_blocks for
|
||||
# local descr, and that makes handling regular flow less clean.
|
||||
for block_id in range(self.num_blocks):
|
||||
block_offset = block_id * self.block_len
|
||||
addr = base_addr + block_offset
|
||||
# (addr, len, device id)
|
||||
blocks_data.append((addr, self.block_len, self.tp_rank))
|
||||
logger.debug("Created %s blocks for src engine %s and rank %s",
|
||||
len(blocks_data), self.engine_id, self.tp_rank)
|
||||
|
||||
descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
|
||||
# NIXL_INIT_AGENT to be used for preparations of local descs.
|
||||
self.src_xfer_side_handle = self.nixl_wrapper.prep_xfer_dlist(
|
||||
"NIXL_INIT_AGENT", descs)
|
||||
|
||||
# After KV Caches registered, listen for new connections.
|
||||
metadata = NixlAgentMetadata(
|
||||
engine_id=self.engine_id,
|
||||
agent_metadata=self.nixl_wrapper.get_agent_metadata(),
|
||||
kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id],
|
||||
num_blocks=self.num_blocks,
|
||||
)
|
||||
tp_size=self.world_size,
|
||||
block_len=self.block_len)
|
||||
ready_event = threading.Event()
|
||||
self._nixl_handshake_listener_t = threading.Thread(
|
||||
target=self._nixl_handshake_listener,
|
||||
@ -543,50 +604,123 @@ class NixlConnectorWorker:
|
||||
self._nixl_handshake_listener_t.start()
|
||||
ready_event.wait()
|
||||
|
||||
def add_remote_agent(self, nixl_agent_meta: NixlAgentMetadata):
|
||||
def add_remote_agent(self,
|
||||
nixl_agent_meta: NixlAgentMetadata,
|
||||
remote_tp_rank: int = 0):
|
||||
"""
|
||||
Add the remote NIXL agent and prepare the descriptors for reading cache
|
||||
blocks from remote.
|
||||
|
||||
In particular, handle both homogeneous and heterogeneous TP. The former
|
||||
requires local rank_i to read from remote rank_i.
|
||||
The latter, assuming D.world_size > P.world_size, requires that two or
|
||||
more local TP workers share the xfer from a single remote TP worker.
|
||||
|
||||
Here's an example:
|
||||
|
||||
rank_offset p_remote_tp_rank
|
||||
(kv split no)
|
||||
--------------------------------
|
||||
0 0 Worker0 ---- 1st half of KV ----> Worker0 [ KV Cache ]
|
||||
/
|
||||
1 0 Worker1 ---- 2nd half of KV -----/
|
||||
|
||||
0 1 Worker2 ---- 1st half of KV ----> Worker1 [ KV Cache ]
|
||||
/
|
||||
1 1 Worker3 ---- 2nd half of KV -----/
|
||||
|
||||
|
||||
Decoder TP workers Prefill TP workers
|
||||
(world_size=4) (world_size=2)
|
||||
tp_ratio = 4 // 2 = 2
|
||||
|
||||
Considering the KV Caches, if P-Worker_i has cache size [2, num_blocksP, kv_heads, block_size, head_dim]
|
||||
then D-Worker_j has [2, num_blocksD, kv_heads//tp_ratio, block_size, head_dim]. Mind the "HND" layout format.
|
||||
Assuming num_blocksD >= num_blocksP, D-Worker0 reads from P-Worker0 by preparing the kv_heads//tp_ratio
|
||||
first heads from all the slots of all the blocks. D-Worker1 will do the same, but reading the second split
|
||||
along the kv_heads dimension, and so forth until "tp_ratio" D TP workers have pulled from P-Worker0.
|
||||
|
||||
Note that the above will also hold true for the homogeneous TP case, where tp_ratio evaluates to 1.
|
||||
|
||||
Regarding MLA case, the cache is replicated across TP workers so the rank_offset will just always be 0
|
||||
so that the whole cache is shared by "tp_ratio" D TP workers.
|
||||
""" # noqa: E501
|
||||
engine_id = nixl_agent_meta.engine_id
|
||||
assert engine_id != self.engine_id, "Conflict engine id found!"
|
||||
if engine_id in self._remote_agents:
|
||||
# TODO re-evaluate refreshing for scaling/recovery
|
||||
if remote_tp_rank in self._remote_agents.get(engine_id, ()):
|
||||
return
|
||||
|
||||
self._remote_agents[engine_id] = self.nixl_wrapper.add_remote_agent(
|
||||
nixl_agent_meta.agent_metadata)
|
||||
self.kv_caches_base_addr[
|
||||
engine_id] = nixl_agent_meta.kv_caches_base_addr
|
||||
if engine_id in self._tp_size:
|
||||
assert self._tp_size[engine_id] == nixl_agent_meta.tp_size
|
||||
else:
|
||||
self._tp_size[engine_id] = nixl_agent_meta.tp_size
|
||||
self._remote_agents[engine_id][
|
||||
remote_tp_rank] = self.nixl_wrapper.add_remote_agent(
|
||||
nixl_agent_meta.agent_metadata)
|
||||
|
||||
# Number of D TP workers reading from a single P TP worker. This is
|
||||
# 1 when P and D `--tensor-parallel-size` match.
|
||||
assert self._tp_size[self.engine_id] % self._tp_size[engine_id] == 0, \
|
||||
"Local TP size must be divisible by remote TP size."
|
||||
tp_ratio = self._tp_size[self.engine_id] // self._tp_size[engine_id]
|
||||
assert tp_ratio > 0, "Decode TP cannot be smaller than"
|
||||
" prefill TP"
|
||||
if self.use_mla:
|
||||
# With MLA the only difference is in the number of blocks.
|
||||
remote_block_size = nixl_agent_meta.block_len / (
|
||||
self.slot_size_bytes)
|
||||
assert self.block_len == nixl_agent_meta.block_len
|
||||
else:
|
||||
remote_block_size = nixl_agent_meta.block_len / (
|
||||
self.slot_size_bytes * tp_ratio)
|
||||
|
||||
assert nixl_agent_meta.block_len == self.block_len * tp_ratio, \
|
||||
"Remote P worker KV layer cache must be of shape [2, N, \
|
||||
local_kv_heads*tp_ratio, block_size, head_dim] and same dtype."
|
||||
|
||||
assert self.block_size == remote_block_size, "Remote P worker with \
|
||||
different block size is not supported"
|
||||
|
||||
assert self.num_blocks >= nixl_agent_meta.num_blocks
|
||||
|
||||
# Create dst descs and xfer side handles. TP workers have same #blocks.
|
||||
if engine_id in self.dst_num_blocks:
|
||||
assert self.dst_num_blocks[engine_id] == nixl_agent_meta.num_blocks
|
||||
else:
|
||||
self.dst_num_blocks[engine_id] = nixl_agent_meta.num_blocks
|
||||
|
||||
# Create src descs and xfer side handles.
|
||||
blocks_data = []
|
||||
for base_addr in self.kv_caches_base_addr[self.engine_id]:
|
||||
for block_id in range(self.num_blocks):
|
||||
block_offset = block_id * self.block_len
|
||||
# (addr, len, device id)
|
||||
blocks_data.append(
|
||||
(base_addr + block_offset, self.block_len, self.tp_rank))
|
||||
logger.debug("Created %s blocks for src engine %s and tp_rank %s",
|
||||
len(blocks_data), self.engine_id, self.tp_rank)
|
||||
# With homogeneous TP, D pulls the whole kv cache from corresponding
|
||||
# rank. With heterogeneous TP, prepare the descriptors by splitting the
|
||||
# P KV cache along kv_head dim, of D worker's kv_head size (D>P).
|
||||
# Eg. PTP1 DTP2 => P0 KV:[block0-KV_0 | block0-KV_1..].
|
||||
p_remote_tp_rank = self.tp_rank // tp_ratio
|
||||
# Only register the remote's descriptors if current rank pulls from it.
|
||||
if p_remote_tp_rank == remote_tp_rank:
|
||||
self.kv_caches_base_addr[
|
||||
engine_id] = nixl_agent_meta.kv_caches_base_addr
|
||||
rank_offset = self.tp_rank % tp_ratio * self.block_len \
|
||||
if not self.use_mla else 0
|
||||
# Register all remote blocks, but only the corresponding kv heads.
|
||||
for base_addr in nixl_agent_meta.kv_caches_base_addr:
|
||||
for block_id in range(nixl_agent_meta.num_blocks):
|
||||
block_offset = block_id * nixl_agent_meta.block_len
|
||||
# For each block, grab the heads chunk belonging to rank_i
|
||||
# of size remote_nheads // tp_ratio, which correspond to
|
||||
# self.block_len == remote_block_len//tp_ratio bytes.
|
||||
addr = base_addr + block_offset + rank_offset
|
||||
# (addr, len, device id)
|
||||
blocks_data.append((addr, self.block_len, remote_tp_rank))
|
||||
logger.debug(
|
||||
"Created %s blocks for dst engine %s with remote rank %s and " \
|
||||
"local rank %s",
|
||||
len(blocks_data), engine_id, remote_tp_rank, self.tp_rank)
|
||||
|
||||
# Register with NIXL.
|
||||
descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
|
||||
self.src_xfer_side_handle = self.nixl_wrapper.prep_xfer_dlist(
|
||||
"NIXL_INIT_AGENT", descs)
|
||||
|
||||
# Create dst descs and xfer side handles.
|
||||
self.dst_num_blocks[engine_id] = nixl_agent_meta.num_blocks
|
||||
blocks_data = []
|
||||
for base_addr in self.kv_caches_base_addr[engine_id]:
|
||||
for block_id in range(nixl_agent_meta.num_blocks):
|
||||
block_offset = block_id * self.block_len
|
||||
# (addr, len, device id)
|
||||
blocks_data.append(
|
||||
(base_addr + block_offset, self.block_len, self.tp_rank))
|
||||
logger.debug("Created %s blocks for dst engine %s and tp_rank %s",
|
||||
len(blocks_data), engine_id, self.tp_rank)
|
||||
|
||||
# Register with NIXL.
|
||||
descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
|
||||
self.dst_xfer_side_handles[
|
||||
engine_id] = self.nixl_wrapper.prep_xfer_dlist(
|
||||
self._remote_agents[engine_id], descs)
|
||||
# Register with NIXL.
|
||||
descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
|
||||
self.dst_xfer_side_handles[
|
||||
engine_id] = self.nixl_wrapper.prep_xfer_dlist(
|
||||
self._remote_agents[engine_id][remote_tp_rank], descs)
|
||||
|
||||
def get_finished(self) -> tuple[set[str], set[str]]:
|
||||
"""
|
||||
@ -654,16 +788,25 @@ class NixlConnectorWorker:
|
||||
return done_sending, done_recving
|
||||
|
||||
def _get_new_notifs(self) -> set[str]:
|
||||
"""Get req_ids which got a remote xfer message."""
|
||||
|
||||
"""
|
||||
Get req_ids which got a remote xfer message. When multiple consumers
|
||||
are reading from the same producer (heterogeneous TP scenario), wait
|
||||
for all consumers to be done pulling.
|
||||
"""
|
||||
notified_req_ids: set[str] = set()
|
||||
for req_ids in self.nixl_wrapper.get_new_notifs().values():
|
||||
for req_id in req_ids:
|
||||
assert req_id not in notified_req_ids
|
||||
notified_req_ids.add(req_id.decode("utf-8"))
|
||||
for notifs in self.nixl_wrapper.get_new_notifs().values():
|
||||
for notif in notifs:
|
||||
req_id, tp_ratio = notif.decode("utf-8").rsplit(":", 1)
|
||||
self.consumer_notification_counts_by_req[req_id] += 1
|
||||
# Wait for all consumers (D) to be done reading before freeing.
|
||||
if self.consumer_notification_counts_by_req[req_id] == int(
|
||||
tp_ratio):
|
||||
notified_req_ids.add(req_id)
|
||||
del self.consumer_notification_counts_by_req[req_id]
|
||||
return notified_req_ids
|
||||
|
||||
def _pop_done_transfers(self, transfers: dict[str, list[int]]) -> set[str]:
|
||||
def _pop_done_transfers(
|
||||
self, transfers: dict[str, list[tuple[int, float]]]) -> set[str]:
|
||||
"""
|
||||
Pop completed xfers by checking for DONE state.
|
||||
Args:
|
||||
@ -673,23 +816,17 @@ class NixlConnectorWorker:
|
||||
"""
|
||||
done_req_ids: set[str] = set()
|
||||
for req_id, handles in list(transfers.items()):
|
||||
running_reqs = []
|
||||
for handle in handles:
|
||||
for handle, xfer_stime in handles:
|
||||
xfer_state = self.nixl_wrapper.check_xfer_state(handle)
|
||||
if xfer_state == "DONE":
|
||||
# TODO ptarasiewicz: why abort is throwing errors?
|
||||
# self.nixl_wrapper.release_xfer_handle(handle)
|
||||
self.nixl_wrapper.release_xfer_handle(handle)
|
||||
done_req_ids.add(req_id)
|
||||
del transfers[req_id]
|
||||
elif xfer_state == "PROC":
|
||||
continue
|
||||
if xfer_state == "PROC":
|
||||
running_reqs.append(handle)
|
||||
else:
|
||||
raise RuntimeError("Transfer failed with state %s",
|
||||
xfer_state)
|
||||
if len(running_reqs) == 0:
|
||||
done_req_ids.add(req_id)
|
||||
del transfers[req_id]
|
||||
else:
|
||||
transfers[req_id] = running_reqs
|
||||
return done_req_ids
|
||||
|
||||
def start_load_kv(self, metadata: NixlConnectorMetadata):
|
||||
@ -735,13 +872,19 @@ class NixlConnectorWorker:
|
||||
# saturate IB with heterogeneous TP sizes. We should remove the staging
|
||||
# blocks until we are ready.
|
||||
|
||||
# Number of D TP workers that will read from dst P. Propagate tp_ratio
|
||||
# on notification so that dst worker can wait before freeing blocks.
|
||||
tp_ratio = self._tp_size[
|
||||
self.engine_id] // self._tp_size[dst_engine_id]
|
||||
notif_id = f"{request_id}:{tp_ratio}".encode()
|
||||
|
||||
# Full prefix cache hit: do not need to read remote blocks,
|
||||
# just notify P worker that we have the blocks we need.
|
||||
num_local_blocks = len(local_block_ids)
|
||||
if num_local_blocks == 0:
|
||||
agent_name = self._remote_agents[dst_engine_id]
|
||||
self.nixl_wrapper.send_notif(agent_name,
|
||||
notif_msg=request_id.encode("utf-8"))
|
||||
remote_rank = self.tp_rank // tp_ratio
|
||||
agent_name = self._remote_agents[dst_engine_id][remote_rank]
|
||||
self.nixl_wrapper.send_notif(agent_name, notif_msg=notif_id)
|
||||
return
|
||||
|
||||
# Partial prefix cache hit: just read uncomputed blocks.
|
||||
@ -754,6 +897,10 @@ class NixlConnectorWorker:
|
||||
local_xfer_side_handle = self.src_xfer_side_handle
|
||||
remote_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id]
|
||||
|
||||
# NOTE (nicolo) With homogeneous TP, each TP worker loads KV from
|
||||
# corresponding rank. With heterogeneous TP, fixing D>P, the D tp
|
||||
# workers will issue xfers to parts of the P worker remote kv caches.
|
||||
|
||||
# Get descs ids.
|
||||
local_block_descs_ids: list[int] = []
|
||||
remote_block_descs_ids: list[int] = []
|
||||
@ -797,14 +944,16 @@ class NixlConnectorWorker:
|
||||
local_block_descs_ids,
|
||||
remote_xfer_side_handle,
|
||||
remote_block_descs_ids,
|
||||
notif_msg=request_id.encode("utf-8"),
|
||||
notif_msg=notif_id,
|
||||
)
|
||||
|
||||
# Begin async xfer.
|
||||
self.nixl_wrapper.transfer(handle)
|
||||
|
||||
# Use handle to check completion in future step().
|
||||
self._recving_transfers[request_id].append(handle)
|
||||
# TODO (NickLucche) surface xfer elapsed time
|
||||
self._recving_transfers[request_id].append(
|
||||
(handle, time.perf_counter()))
|
||||
|
||||
def _get_block_descs_ids(self,
|
||||
engine_id: str,
|
||||
@ -815,7 +964,6 @@ class NixlConnectorWorker:
|
||||
If layer_idx is provided, we use the region_ids for the given layer.
|
||||
Otherwise, we use all regions.
|
||||
"""
|
||||
|
||||
if layer_idx is None:
|
||||
region_ids = range(self.num_regions)
|
||||
else:
|
||||
|
||||
@ -16,6 +16,8 @@ from vllm.attention.ops.merge_attn_states import merge_attn_states
|
||||
from vllm.attention.utils.fa_utils import (flash_attn_supports_fp8,
|
||||
get_flash_attn_version)
|
||||
from vllm.config import VllmConfig, get_layers_from_vllm_config
|
||||
from vllm.distributed.kv_transfer.kv_connector.utils import (
|
||||
get_kv_connector_cache_layout)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import cdiv
|
||||
@ -70,6 +72,20 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
raise ValueError("Block size must be a multiple of 16.")
|
||||
return (2, num_blocks, block_size, num_kv_heads, head_size)
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_stride_order() -> tuple[int, ...]:
|
||||
# NOTE When running disaggregated PD with NIXL, HND layout is used for
|
||||
# faster transfer. `stride_order` indicates the permutation that gets
|
||||
# us from `get_kv_cache_shape` to the actual memory layout we want.
|
||||
cache_layout = get_kv_connector_cache_layout()
|
||||
if cache_layout == "NHD":
|
||||
stride_order = (0, 1, 2, 3, 4)
|
||||
elif cache_layout == "HND":
|
||||
stride_order = (0, 1, 3, 2, 4)
|
||||
else:
|
||||
raise ValueError("Unknown cache layout format %s.", cache_layout)
|
||||
return stride_order
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlashAttentionMetadata:
|
||||
|
||||
@ -597,7 +597,8 @@ class WorkerWrapperBase:
|
||||
|
||||
def initialize_from_config(self, kv_cache_configs: List[Any]) -> None:
|
||||
kv_cache_config = kv_cache_configs[self.rpc_rank]
|
||||
self.worker.initialize_from_config(kv_cache_config) # type: ignore
|
||||
with set_current_vllm_config(self.vllm_config):
|
||||
self.worker.initialize_from_config(kv_cache_config) # type: ignore
|
||||
|
||||
def init_device(self):
|
||||
with set_current_vllm_config(self.vllm_config):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user