From 6de0982dd055abfe5332ede37c7c907d27f15298 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Sun, 6 Apr 2025 14:07:43 +0000 Subject: [PATCH] added Signed-off-by: rshaw@neuralmagic.com --- vllm/core/event_manager.py | 108 +++++ .../device_communicators/kv_rearrange.py | 110 +++++ vllm/distributed/device_communicators/nixl.py | 379 ++++++++++++++++++ .../kv_connector/dynamo_connector.py | 350 ++++++++++++++++ .../kv_transfer/kv_pipe/dynamo_nccl_pipe.py | 124 ++++++ vllm/remote_prefill.py | 67 ++++ 6 files changed, 1138 insertions(+) create mode 100644 vllm/core/event_manager.py create mode 100644 vllm/distributed/device_communicators/kv_rearrange.py create mode 100644 vllm/distributed/device_communicators/nixl.py create mode 100644 vllm/distributed/kv_transfer/kv_connector/dynamo_connector.py create mode 100644 vllm/distributed/kv_transfer/kv_pipe/dynamo_nccl_pipe.py create mode 100644 vllm/remote_prefill.py diff --git a/vllm/core/event_manager.py b/vllm/core/event_manager.py new file mode 100644 index 0000000000000..a27af5808a42d --- /dev/null +++ b/vllm/core/event_manager.py @@ -0,0 +1,108 @@ +# SPDX-License-Identifier: Apache-2.0 +import ctypes +import logging +import uuid +from ctypes import c_char_p, c_size_t, c_uint32, c_void_p, c_int64 +from typing import Optional + +from vllm.core.block.prefix_caching_block import PrefixCachingBlock, PrefixHash + +logger = logging.getLogger(__name__) + + +class DynamoResult: + OK = 0 + ERR = 1 + + +class KVCacheEventManager: + + def __init__(self, namespace: str, component: str, worker_id: int, + lib_path: str, kv_block_size: int): + self.lib = None + + try: + self.lib = ctypes.CDLL(lib_path) + self.lib.dynamo_llm_init.argtypes = [ + c_char_p, + c_char_p, + c_int64, + c_uint32, + ] + self.lib.dynamo_llm_init.restype = c_uint32 + + result = self.lib.dynamo_llm_init( + namespace.encode(), component.encode(), worker_id, kv_block_size + ) + if result == DynamoResult.OK: + logger.info( + "KVCacheEventManager initialized successfully. 
Ready to publish KV Cache Events" + ) + else: + logger.info("KVCacheEventManager initialization failed!") + + except Exception as e: + print(f"Failed to load {lib_path}") + raise e + + self.lib.dynamo_kv_event_publish_stored.argtypes = [ + ctypes.c_uint64, # event_id + ctypes.POINTER(ctypes.c_uint32), # token_ids + ctypes.POINTER(ctypes.c_size_t), # num_block_tokens + ctypes.POINTER(ctypes.c_uint64), # block_ids + ctypes.c_size_t, # num_blocks + ctypes.POINTER(ctypes.c_uint64), # parent_hash + ctypes.c_uint64, # lora_id + ] + self.lib.dynamo_kv_event_publish_stored.restype = ctypes.c_uint32 # dynamo_llm_result_t + + self.lib.dynamo_kv_event_publish_removed.argtypes = [ + ctypes.c_uint64, # event_id + ctypes.POINTER(ctypes.c_uint64), # block_ids + ctypes.c_size_t, # num_blocks + ] + self.lib.dynamo_kv_event_publish_removed.restype = ctypes.c_uint32 # dynamo_llm_result_t + + self.event_id_counter = 0 + + def enqueue_stored_event(self, parent: Optional[PrefixCachingBlock], + block: PrefixCachingBlock): + token_ids_arr = (ctypes.c_uint32 * + len(block.token_ids))(*block.token_ids) + num_block_tokens = (ctypes.c_size_t * 1)(len(block.token_ids)) + block_hash = (ctypes.c_uint64 * 1)(block.content_hash) + parent_hash = ((ctypes.c_uint64 * 1)(parent.content_hash) + if parent is not None else None) + + # Publish the event + result = self.lib.dynamo_kv_event_publish_stored( + self.event_id_counter, # uint64_t event_id + token_ids_arr, # const uint32_t *token_ids + num_block_tokens, # const uintptr_t *num_block_tokens + block_hash, # const uint64_t *block_ids + 1, # uintptr_t num_blocks + parent_hash, # const uint64_t *parent_hash + 0, # uint64_t lora_id + ) + + if result == DynamoResult.OK: + logger.debug(f"Store - Published KV Event: {block.content_hash}") + else: + logger.debug( + f"Store - Failed to Publish KV Event: {block.content_hash}") + + self.event_id_counter += 1 + + def enqueue_removed_event(self, block_hash: PrefixHash): + result = self.lib.dynamo_kv_event_publish_removed( + self.event_id_counter, + (ctypes.c_uint64 * 1)(block_hash), + 1, + ) + + if result == DynamoResult.OK: + logger.debug(f"Remove - Published KV Event: {block_hash}") + else: + logger.debug(f"Remove - Failed to Publish KV Event: {block_hash}") + + self.event_id_counter += 1 diff --git a/vllm/distributed/device_communicators/kv_rearrange.py b/vllm/distributed/device_communicators/kv_rearrange.py new file mode 100644 index 0000000000000..b9485bd5599ae --- /dev/null +++ b/vllm/distributed/device_communicators/kv_rearrange.py @@ -0,0 +1,110 @@ +import torch +import triton +import triton.language as tl + +@triton.jit +def rearrange_kernel_read( + t1_ptr, + t2_ptr, + N, + B, + H, + C, + d, + tensor_subset_size, + block_size, + token_size, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(0) + + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + + curr_n = offsets // block_size + curr_b = offsets // token_size % B + curr_h = offsets // C % H + curr_c = offsets % C + + src_pos = offsets + + tp_group = curr_h * d // H + dst_h = curr_h % (H // d) + tp_group_offset = curr_n * (block_size // d) + curr_b * (H // d) * C + dst_h * C + curr_c + + dst_pos = tensor_subset_size * tp_group + tp_group_offset + + tl.store(t1_ptr + src_pos, tl.load(t2_ptr + dst_pos)) + +@triton.jit +def rearrange_kernel_write( + t1_ptr, + t2_ptr, + N, + B, + H, + C, + d, + tensor_subset_size, + block_size, + token_size, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(0) + + block_start = pid * BLOCK_SIZE + offsets = 
block_start + tl.arange(0, BLOCK_SIZE) + + curr_n = offsets // block_size + curr_b = offsets // token_size % B + curr_h = offsets // C % H + curr_c = offsets % C + + src_pos = offsets + + tp_group = curr_h * d // H + dst_h = curr_h % (H // d) + tp_group_offset = curr_n * (block_size // d) + curr_b * (H // d) * C + dst_h * C + curr_c + + dst_pos = tensor_subset_size * tp_group + tp_group_offset + + tl.store(t2_ptr + dst_pos, tl.load(t1_ptr + src_pos)) + + + +def rearrange_tensors(t1: torch.Tensor, t2: torch.Tensor, d: int, direction: str): + N, B, H, C = t1.shape + + assert t2.shape == (N, B, H, C), "Destination tensor must have same shape as source" + assert H % d == 0, "H must be divisible by d" + + block_size = B * H * C + token_size = H * C + tensor_size = N * block_size + tensor_subset_size = tensor_size // d + + BLOCK_SIZE = 1024 + grid = ((N * B * H * C + BLOCK_SIZE - 1) // BLOCK_SIZE,) + + if direction == "read": + rearrange_kernel_read[grid]( + t1, t2, + N, B, H, C, + d, + tensor_subset_size, + block_size, + token_size, + BLOCK_SIZE=BLOCK_SIZE + ) + elif direction == "write": + rearrange_kernel_write[grid]( + t1, t2, + N, B, H, C, + d, + tensor_subset_size, + block_size, + token_size, + BLOCK_SIZE=BLOCK_SIZE + ) + else: + raise ValueError(f"Invalid direction: {direction}") \ No newline at end of file diff --git a/vllm/distributed/device_communicators/nixl.py b/vllm/distributed/device_communicators/nixl.py new file mode 100644 index 0000000000000..a8bd202f55587 --- /dev/null +++ b/vllm/distributed/device_communicators/nixl.py @@ -0,0 +1,379 @@ +import torch +from typing import List, Tuple +from vllm.config import VllmConfig +from vllm.logger import init_logger +import msgspec +import time +import uuid +from collections import defaultdict +from .kv_rearrange import rearrange_tensors + +logger = init_logger(__name__) + +# Lazy import nixl_wrapper to avoid loading nixl_bindings if nixl is not used +try: + from nixl._api import nixl_agent as NixlWrapper + logger.info("NIXL is available") +except ImportError: + logger.warning("NIXL is not available") + NixlWrapper = None + +class NixlMetadata( + msgspec.Struct, + omit_defaults=True, # type: ignore[call-arg] + # required for @cached_property. 
+ dict=True): + engine_id: str + agent_metadata: List[bytes] + kv_caches_base_addr: List[List[Tuple[int, int]]] # base address for each rank for each layer for keys and values + num_blocks: int + + +class DynamoNixlConnector: + def __init__(self, vllm_config: VllmConfig, engine_id: str, rank: int): + self.vllm_config = vllm_config + if NixlWrapper is None: + logger.error("NIXL is not available") + raise RuntimeError("NIXL is not available") + logger.info("Initializing NIXL wrapper") + self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), None) + + self.use_prepped_xfer = vllm_config.kv_transfer_config.use_prepped_xfer + + self.num_layers = None + self.num_blocks = None + self.num_heads = None + self.block_len = None + self.kv_caches = None + self.kv_caches_base_addr = {} + self.kv_cache_shape = {} + + self._registered_descs = [] + self._remote_agents = {} + self.engine_id = engine_id + self.rank = rank + self._tp_size = {} + self.src_xfer_side_handles = {} + self.dst_xfer_side_handles = defaultdict(dict) + self.dst_num_blocks = {} + + self._transfers = defaultdict(list) + + + self._tp_size[engine_id] = vllm_config.parallel_config.tensor_parallel_size + + + @property + def agent_name(self): + return self.nixl_wrapper.name + + def register_kv_caches(self, kv_caches: List[torch.Tensor]): + _, num_blocks, block_size, num_heads, head_dim = kv_caches[0].shape + self.block_len = block_size * num_heads * head_dim * kv_caches[0].element_size() + logger.debug("Per layer kv cache size: %s", kv_caches[0].shape) + self.num_layers = len(kv_caches) + self.num_blocks = num_blocks + self.num_heads = num_heads + self.kv_caches = kv_caches + kv_caches_base_addr = [] + caches_data = [] + for key_cache, value_cache in kv_caches: + base_addr = key_cache.data_ptr() + region_len = 2 * num_blocks * self.block_len + caches_data.append((base_addr, region_len, self.rank, "")) + kv_caches_base_addr.append((key_cache.data_ptr(), value_cache.data_ptr())) + + self.kv_caches_base_addr[self.engine_id] = kv_caches_base_addr + + descs = self.nixl_wrapper.get_reg_descs(caches_data, "VRAM") + logger.debug("Registering descs: %s", caches_data) + self.nixl_wrapper.register_memory(descs) + self._registered_descs.append(descs) + + def get_agent_metadata(self): + return self.nixl_wrapper.get_agent_metadata() + + def shutdown(self): + for descs_list in self._registered_descs: + self.nixl_wrapper.deregister_memory(descs_list) + for agent_names in self._remote_agents.values(): + for agent_name in agent_names: + self.nixl_wrapper.remove_remote_agent(agent_name) + for src_xfer_side_handle in self.src_xfer_side_handles.values(): + self.nixl_wrapper.release_dlist_handle(src_xfer_side_handle) + for dst_xfer_side_handles in self.dst_xfer_side_handles.values(): + for dst_xfer_side_handle in dst_xfer_side_handles.values(): + self.nixl_wrapper.delete_xfer_side(dst_xfer_side_handle) + + def _get_ranges(self, block_ids): + # This function should return a list of ranges of block ids that are contiguous + # For example, if block_ids is [0, 1, 2, 4, 5, 6], the function should return [[0, 2], [4, 6]] + # The ranges are sorted by the starting block id + # The function should also make sure that the block ids are contiguous + # If the block ids are not contiguous, the function should raise an error + ranges = [] + for i in range(len(block_ids)): + if i == 0 or block_ids[i] != block_ids[i-1] + 1: + ranges.append([block_ids[i], block_ids[i]]) + else: + ranges[-1][1] = block_ids[i] + return ranges + + def _get_block_descs_ids(self, engine_id, layer_ids, 
block_ids, i=None, tp_multiplier=1, staging_ranges=None): + + if layer_ids == "all": + layer_ids = list(range(self.num_layers)) + if block_ids == "all": + block_ids = list(range(self.num_blocks)) + + descs_ids = [] + + + if i is not None: + num_blocks = self.num_blocks + for layer_id in layer_ids: + for is_value in [0, 1]: + staging_range_idx = 0 + for block_id in block_ids: + if block_id > staging_ranges[staging_range_idx][1] or block_id < staging_ranges[staging_range_idx][0]: + staging_range_idx += 1 + start_offset = staging_ranges[staging_range_idx][0] + i_offset = i * (staging_ranges[staging_range_idx][-1] - start_offset + 1) + descs_ids.append(layer_id * 2 * num_blocks * tp_multiplier + is_value * num_blocks * tp_multiplier + start_offset * tp_multiplier + i_offset + (block_id - start_offset)) + else: + num_blocks = self.dst_num_blocks[engine_id] + for layer_id in layer_ids: + for is_value in [0, 1]: + for block_id in block_ids: + descs_ids.append(layer_id * 2 * num_blocks + is_value * num_blocks + block_id) + return descs_ids + + def _get_same_length_ranges(self, src_ranges, dst_ranges, return_original_src_ranges=False): + # This function should return a list of ranges for both src and dst so that corresponding ranges are the same length + # For example, if src_ranges is [[0, 2] [4, 8]] and dst_ranges is [[1, 3], [5, 7], [9, 10]] + # The function should return ([[0, 2], [4, 6], [7, 8]], [[1, 3], [5, 7], [9, 10]]) + src_overlapping_ranges, dst_overlapping_ranges = [], [] + + original_src_ranges = [] + org_src_range = tuple(src_ranges[0]) + + src_idx, dst_idx = 0, 0 + while src_idx < len(src_ranges) and dst_idx < len(dst_ranges): + src_range = src_ranges[src_idx] + dst_range = dst_ranges[dst_idx] + + # Calculate the length of each range + src_len = src_range[-1] - src_range[0] + 1 + dst_len = dst_range[-1] - dst_range[0] + 1 + + # If ranges have the same length, add them directly + if src_len == dst_len: + src_overlapping_ranges.append([src_range[0], src_range[-1]]) + dst_overlapping_ranges.append([dst_range[0], dst_range[-1]]) + original_src_ranges.append(org_src_range) + src_idx += 1 + dst_idx += 1 + if src_idx < len(src_ranges): + org_src_range = tuple(src_ranges[src_idx]) + # If source range is longer, split it + elif src_len > dst_len: + src_overlapping_ranges.append([src_range[0], src_range[0] + dst_len - 1]) + dst_overlapping_ranges.append([dst_range[0], dst_range[-1]]) + original_src_ranges.append(org_src_range) + # Update source range for next iteration + src_ranges[src_idx] = [src_range[0] + dst_len, src_range[-1]] + dst_idx += 1 + # If destination range is longer, split it + else: # src_len < dst_len + src_overlapping_ranges.append([src_range[0], src_range[-1]]) + dst_overlapping_ranges.append([dst_range[0], dst_range[0] + src_len - 1]) + original_src_ranges.append(org_src_range) + # Update destination range for next iteration + dst_ranges[dst_idx] = [dst_range[0] + src_len, dst_range[-1]] + src_idx += 1 + if src_idx < len(src_ranges): + org_src_range = tuple(src_ranges[src_idx]) + if return_original_src_ranges: + return src_overlapping_ranges, dst_overlapping_ranges, original_src_ranges + return src_overlapping_ranges, dst_overlapping_ranges + + def read_blocks(self, local_block_ids, staging_block_ids, remote_block_ids, dst_engine_id): + logger.debug("Reading %d blocks from %s to %s", len(local_block_ids), self.agent_name, dst_engine_id) + + assert len(local_block_ids) == len(staging_block_ids) == len(remote_block_ids) + + if len(local_block_ids) == 0: + logger.debug("No 
blocks to read") + return + + start_time = time.perf_counter() + + local_ranges = self._get_ranges(local_block_ids) + staging_ranges = self._get_ranges(staging_block_ids) + + local_rearranging_ranges, staging_rearranging_ranges = self._get_same_length_ranges(local_ranges, staging_ranges) + + tp_multiplier = self._tp_size[dst_engine_id] // self._tp_size[self.engine_id] + remote_block_descs_ids = self._get_block_descs_ids(dst_engine_id, "all", remote_block_ids) + local_xfer_side_handle = self.src_xfer_side_handles[tp_multiplier] + handles = [] + + logger.debug("Time to get block descs ids: %s ms", (time.perf_counter() - start_time) * 1000) + create_xfer_start_time = time.perf_counter() + + for i in range(tp_multiplier): + staging_block_descs_ids = self._get_block_descs_ids(self.engine_id, "all", staging_block_ids, i=i, tp_multiplier=tp_multiplier, staging_ranges=staging_rearranging_ranges) + assert len(staging_block_descs_ids) == len(remote_block_descs_ids) + remote_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id][i] + handle = self.nixl_wrapper.make_prepped_xfer("READ", local_xfer_side_handle, staging_block_descs_ids, + remote_xfer_side_handle, remote_block_descs_ids, + "") + handles.append(handle) + status = self.nixl_wrapper.transfer(handle) + + logger.debug("Time to create xfer: %s ms", (time.perf_counter() - create_xfer_start_time) * 1000) + + transfer_start_time = time.perf_counter() + + for handle in handles: + while (status := self.nixl_wrapper.check_xfer_state(handle)) != "DONE": + if status == "PROC": + time.sleep(0.001) + else: + raise RuntimeError("Read transfer failed with state %s", status) + # self.nixl_wrapper.abort_xfer(handle) # TODO ptarasiewicz: why abort is throwing errors? + + logger.debug("Time to transfer: %s ms", (time.perf_counter() - transfer_start_time) * 1000) + + rearrange_start_time = time.perf_counter() + + for local_range, staging_range in zip(local_rearranging_ranges, staging_rearranging_ranges): + logger.debug("Rearranging tensors for cache: %s, local_range: %s, staging_range: %s", self.kv_caches[0].shape, local_range, staging_range) + for kv_cache in self.kv_caches: + for cache in kv_cache: + rearrange_tensors(cache[local_range[0]:local_range[1] + 1], cache[staging_range[0]:staging_range[1] + 1], tp_multiplier, "read") + + logger.debug("Time to rearrange tensors: %s ms", (time.perf_counter() - rearrange_start_time) * 1000) + logger.debug("Total time for read: %s ms", (time.perf_counter() - start_time) * 1000) + + def write_blocks(self, local_block_ids, staging_block_ids, remote_block_ids, dst_engine_id, notify_msg): + logger.debug("Writing %d blocks to %s from %s with notify message %s", len(local_block_ids), dst_engine_id, self.agent_name, notify_msg) + + # hongkuanz: we send isl[:-1] tokens to the prefill where the kv for the last + # isl[-1] token is calculated in the first iteration in decode. + # If isl equals to a multiple of tokens_per_block + 1, prefill engine will have \ + # one less block due to the missing last token. 
+ remote_block_ids = remote_block_ids[:len(local_block_ids)] + + assert len(staging_block_ids) == len(local_block_ids) + tp_multiplier = self._tp_size[dst_engine_id] // self._tp_size[self.engine_id] + + if len(local_block_ids) == 0: + logger.debug("No blocks to write") + for i in range(tp_multiplier): + self.nixl_wrapper.send_notif(self._remote_agents[dst_engine_id][self.rank * tp_multiplier + i], notify_msg) + return + + start_time = time.perf_counter() + + local_ranges = self._get_ranges(local_block_ids) + staging_ranges = self._get_ranges(staging_block_ids) + + local_rearranging_ranges, staging_rearranging_ranges = self._get_same_length_ranges(local_ranges, staging_ranges) + + for local_range, staging_range in zip(local_rearranging_ranges, staging_rearranging_ranges): + logger.debug("Rearranging tensors for cache: %s, local_range: %s, staging_range: %s", self.kv_caches[0].shape, local_range, staging_range) + for kv_cache in self.kv_caches: + for cache in kv_cache: + rearrange_tensors(cache[local_range[0]:local_range[1] + 1], cache[staging_range[0]:staging_range[1] + 1], tp_multiplier, "write") + + logger.debug("Time to rearrange tensors: %s ms", (time.perf_counter() - start_time) * 1000) + + create_xfer_start_time = time.perf_counter() + + # getting block descs ids + remote_block_descs_ids = self._get_block_descs_ids(dst_engine_id, "all", remote_block_ids) + local_xfer_side_handle = self.src_xfer_side_handles[tp_multiplier] + + for i in range(tp_multiplier): + staging_block_descs_ids = self._get_block_descs_ids(self.engine_id, "all", staging_block_ids, i=i, tp_multiplier=tp_multiplier, staging_ranges=staging_rearranging_ranges) + assert len(staging_block_descs_ids) == len(remote_block_descs_ids) + remote_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id][i] + handle = self.nixl_wrapper.make_prepped_xfer("WRITE", local_xfer_side_handle, staging_block_descs_ids, + remote_xfer_side_handle, remote_block_descs_ids, + notify_msg) + self._transfers[notify_msg].append(handle) + status = self.nixl_wrapper.transfer(handle) + + logger.debug("Time to create xfer: %s ms", (time.perf_counter() - create_xfer_start_time) * 1000) + + transfer_start_time = time.perf_counter() + logger.debug("Total time for write: %s ms", (time.perf_counter() - start_time) * 1000) + + def get_notifs(self): + return self.nixl_wrapper.update_notifs() + + def get_new_notifs(self): + return self.nixl_wrapper.get_new_notifs() + + def add_remote_agent(self, engine_id, agent_metadata, agent_tp, kv_caches_base_addr, num_blocks): + self._tp_size[engine_id] = agent_tp + agent_names = [] + for agent_meta in agent_metadata: + agent_name = self.nixl_wrapper.add_remote_agent(agent_meta) + agent_names.append(agent_name) + self._remote_agents[engine_id] = agent_names + self.kv_caches_base_addr[engine_id] = kv_caches_base_addr + + tp_multiplier = self._tp_size[engine_id] // self._tp_size[self.engine_id] + assert tp_multiplier > 0, f"Decode TP cannot be smaller than prefill TP, got {self._tp_size[engine_id]} and {self._tp_size[self.engine_id]}" + + logger.debug("Creating src xfer side handles for engine %s, tp_multiplier: %s", engine_id, tp_multiplier) + dst_block_len = self.block_len // tp_multiplier + if tp_multiplier not in self.src_xfer_side_handles: + # create descs and xfer side handles + blocks_data = [] + for layer_id in range(self.num_layers): + for base_addr in self.kv_caches_base_addr[self.engine_id][layer_id]: + for block_id in range(self.num_blocks): + block_offset = block_id * self.block_len + for i in 
range(tp_multiplier): + tp_multiplier_offset = i * dst_block_len + blocks_data.append((base_addr + block_offset + tp_multiplier_offset, dst_block_len, self.rank)) + logger.debug("Created %s blocks for src engine %s and rank %s", len(blocks_data), self.engine_id, self.rank * tp_multiplier + i) + descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM") + self.src_xfer_side_handles[tp_multiplier] = self.nixl_wrapper.prep_xfer_dlist("", descs) + + # create dst xfer side handles + self.dst_num_blocks[engine_id] = num_blocks + for i in range(tp_multiplier): + blocks_data = [] + for layer_id in range(self.num_layers): + for base_addr in self.kv_caches_base_addr[engine_id][self.rank * tp_multiplier + i][layer_id]: + for block_id in range(num_blocks): + block_offset = block_id * dst_block_len + blocks_data.append((base_addr + block_offset, dst_block_len, self.rank * tp_multiplier + i)) + logger.debug("Created %s blocks for dst engine %s and rank %s", len(blocks_data), engine_id, self.rank * tp_multiplier + i) + descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM") + self.dst_xfer_side_handles[engine_id][i] = self.nixl_wrapper.prep_xfer_dlist(self._remote_agents[engine_id][self.rank * tp_multiplier + i], descs) + + return agent_names + + def get_done_tranfers(self) -> List[str]: + done_req_ids = [] + for req_id, handles in self._transfers.items(): + running_reqs = [] + for handle in handles: + xfer_state = self.nixl_wrapper.check_xfer_state(handle) + if xfer_state == "DONE": + # self.nixl_wrapper.release_xfer_handle(handle) # TODO ptarasiewicz: why abort is throwing errors? + continue + if xfer_state == "PROC": + running_reqs.append(handle) + else: + raise RuntimeError("Transfer failed with state %s", xfer_state) + if len(running_reqs) == 0: + done_req_ids.append(req_id) + else: + self._transfers[req_id] = running_reqs + return done_req_ids diff --git a/vllm/distributed/kv_transfer/kv_connector/dynamo_connector.py b/vllm/distributed/kv_transfer/kv_connector/dynamo_connector.py new file mode 100644 index 0000000000000..7b3344f8a0a09 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/dynamo_connector.py @@ -0,0 +1,350 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Simple KV Cache Connector for Distributed Machine Learning Inference + +The SimpleConnector transfers KV caches between prefill vLLM worker (KV cache +producer) and decode vLLM worker (KV cache consumer) using PyNcclPipe or +MooncakePipe. + +But the logic can be extended to support other pipe and lookup buffer. 
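+
+The DynamoConnector below reuses PyNcclPipe for the NCCL data path and adds a
+ZMQ-based side channel (DynamoNcclDataPlane) that routes keys, values and hidden
+states by request id, so prefill and decode workers may run with different tensor
+parallel sizes (see tensor_parallel_multiplier).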
+""" +import re +from typing import TYPE_CHECKING, List, Optional, Tuple, Union + +import torch + +from vllm import _custom_ops as ops +from vllm.config import VllmConfig, KVTransferConfig +from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase +from vllm.distributed.utils import StatelessProcessGroup +from vllm.distributed.kv_transfer.kv_lookup_buffer.simple_buffer import ( + SimpleBuffer) +from vllm.logger import init_logger +from vllm.sequence import IntermediateTensors + +if TYPE_CHECKING: + from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata + +logger = init_logger(__name__) + + +class DynamoConnector(KVConnectorBase): + + def __init__( + self, + rank: int, + local_rank: int, + config: VllmConfig, + world_group, + ): + + self.config = config.kv_transfer_config + self.tp_size = config.parallel_config.tensor_parallel_size + self.rank = rank + + if self.config.kv_connector != "DynamoNcclConnector": + raise NotImplementedError("Only DynamoNcclConnector is supported by the DynamoConnector class") + + from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import ( + PyNcclPipe) + from vllm.distributed.kv_transfer.kv_pipe.dynamo_nccl_pipe import ( + DynamoNcclDataPlane) + + logger.info( + "Initializing DynamoNcclConnector under kv_transfer_config %s", + self.config) + + self.lookup_buffer_size = self.config.kv_buffer_size + + self.producer_data_pipe: PyNcclPipe + self.consumer_data_pipe: PyNcclPipe + self.producer_signal_pipe: PyNcclPipe + self.consumer_signal_pipe: PyNcclPipe + + self._broadcast_and_enhance_kv_config(rank, config, world_group) + + self.kv_group_rank = self._get_kv_group_rank(self.config.kv_rank, rank, self.config) + self.tp_size = config.parallel_config.tensor_parallel_size + + # 2 pipes for every rank in the world + if self.config.is_kv_producer: + port_offset_base = rank + 1 + else: + port_offset_base = rank // self.config.tensor_parallel_multiplier + 1 + + + self.local_kv_rank = rank % self.config.tensor_parallel_multiplier + self.global_kv_rank = self._get_global_kv_rank(self.config.kv_rank, rank, self.config) + + self.data_pipe = PyNcclPipe( + kv_group_rank=self.kv_group_rank, + local_rank=local_rank, + config=self.config, + port_offset=port_offset_base, + ) + + self.data_plane = DynamoNcclDataPlane( + data_pipe=self.data_pipe, + port=self._get_data_plane_port(self.global_kv_rank), + ) + + def send_kv_caches_and_hidden_states( + self, + model_executable: torch.nn.Module, + model_input: "ModelInputForGPUWithSamplingMetadata", + kv_caches: List[torch.Tensor], + hidden_or_intermediate_states: Union[torch.Tensor, + IntermediateTensors], + ) -> None: + + input_tokens_tensor = model_input.input_tokens + seq_lens = model_input.attn_metadata.seq_lens + slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten() + start_layer = model_executable.model.start_layer + end_layer = model_executable.model.end_layer + request_ids = list(model_input.request_ids_to_seq_ids.keys()) + + model_config = model_executable.model.config + is_deepseek = "deepseek" in model_config.architectures[0].lower() + if not is_deepseek: + num_heads = int(model_config.num_key_value_heads / self.tp_size) + hidden_size = model_config.hidden_size + num_attention_heads = model_config.num_attention_heads + head_size = int(hidden_size / num_attention_heads) + else: + num_heads = int(model_config.num_key_value_heads / self.tp_size) + hidden_size = model_config.hidden_size + num_attention_heads = model_config.num_attention_heads + head_size = int(4.5 * hidden_size / 
num_attention_heads) + + # query_lens contains new KV caches that are added to vLLM. + # so we will send them to decode instance + # FIXME(Kuntai): This assume that all requests are prefill. + for idx, slen in enumerate(seq_lens): + start_pos = sum(seq_lens[:idx]) + end_pos = start_pos + slen + current_tokens = input_tokens_tensor[start_pos:end_pos] + current_request_id = request_ids[idx] + decode_hostname, decode_kv_rank = self.parse_request_id(current_request_id) + decode_first_global_rank = self._get_global_kv_rank(decode_kv_rank, self.rank * self.config.tensor_parallel_multiplier, self.config) + + for target_rank in range(self.config.tensor_parallel_multiplier): + + keys, values = [], [] + + for layer_id in range(start_layer, end_layer): + kv_cache = kv_caches[layer_id - start_layer] + + current_slot_mapping = slot_mapping_flat[start_pos:end_pos] + + num_heads_per_rank = num_heads // self.config.tensor_parallel_multiplier + head_start = target_rank * num_heads_per_rank + head_end = head_start + num_heads_per_rank + + if not is_deepseek: + key_cache = kv_cache[0].reshape(-1, num_heads, head_size) + value_cache = kv_cache[1].reshape(-1, num_heads, head_size) + keys.append(key_cache[current_slot_mapping, head_start:head_end].unsqueeze(0)) + values.append(value_cache[current_slot_mapping, head_start:head_end].unsqueeze(0)) + else: + key_cache = kv_cache + keys.append(key_cache[current_slot_mapping].unsqueeze(0)) + values.append(torch.empty(0)) + + keys = torch.cat(keys, dim=0) + values = torch.cat(values, dim=0) + + decode_global_rank = decode_first_global_rank + target_rank + decode_port = self._get_data_plane_port(decode_global_rank) + partial_hidden_or_intermediate_states = hidden_or_intermediate_states[start_pos:end_pos] + self._send(decode_hostname, decode_port, current_request_id, keys, values, + partial_hidden_or_intermediate_states) + + logger.debug("[rank%d]: KV send DONE.", torch.distributed.get_rank()) + + def recv_kv_caches_and_hidden_states( + self, model_executable: torch.nn.Module, + model_input: "ModelInputForGPUWithSamplingMetadata", + kv_caches: List[torch.Tensor] + ) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool, + "ModelInputForGPUWithSamplingMetadata"]: + + # When bypass_model_exec is set to False, it means that at least for one + # request its corresponding KV cache or hidden state is missing. + # In this case we need to do prefilling to recompute missing KV cache + # and hidden states. + bypass_model_exec = True + + input_tokens_tensor = model_input.input_tokens + seq_lens = model_input.attn_metadata.seq_lens + slot_mapping = model_input.attn_metadata.slot_mapping.flatten() + request_ids = list(model_input.request_ids_to_seq_ids.keys()) + + hidden_or_intermediate_states_for_one_req = [] + + input_tokens_list = [] + start_pos_list = [] + + model_config = model_executable.model.config + is_deepseek = "deepseek" in model_config.architectures[0].lower() + + # enumerate different requests + # FIXME(Kuntai): This impl assumes that all requests are prefill. 
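+        # The per-request tensors are laid out back to back: with seq_lens = [5, 3]
+        # (illustrative), request 0 occupies positions 0..4 and request 1 positions 5..7
+        # of input_tokens_tensor and slot_mapping, which is what start_pos/end_pos slice
+        # out below. _recv blocks until the prefill side has pushed the keys, values and
+        # hidden states for that request id through the data plane.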
+ for idx, slen in enumerate(seq_lens): + + start_pos = sum(seq_lens[:idx]) + end_pos = start_pos + slen + current_tokens = input_tokens_tensor[start_pos:end_pos] + current_request_id = request_ids[idx] + num_tokens = slen + + # collecting data for rebuilding the input + input_tokens_list.append(current_tokens) + start_pos_list.append(start_pos) + + ret = self._recv(current_request_id) + keys: torch.Tensor = ret[0] + values: torch.Tensor = ret[1] + hidden: torch.Tensor = ret[2] + + # put received KV caches into paged memory + for i in range(model_executable.model.start_layer, + model_executable.model.end_layer): + + kv_cache = kv_caches[i - model_executable.model.start_layer] + layer = model_executable.model.layers[i] + + if not is_deepseek: + key_cache, value_cache = kv_cache[0], kv_cache[1] + ops.reshape_and_cache_flash( + keys[i - model_executable.model.start_layer].to( + key_cache.device), + values[i - model_executable.model.start_layer].to( + value_cache.device), + key_cache, + value_cache, + slot_mapping[start_pos:end_pos], + layer.self_attn.attn.kv_cache_dtype, + layer.self_attn.attn._k_scale, + layer.self_attn.attn._v_scale, + ) + else: + key_cache = kv_cache + copy_from =keys[i - model_executable.model.start_layer].to( + key_cache.device) + kv_cache[slot_mapping[start_pos:end_pos]] = copy_from + + hidden_or_intermediate_states_for_one_req.append(hidden) + + if not bypass_model_exec: + # Some of the KV cache is not retrieved + # Here we will fall back to normal model forwarding + # But optionally you can adjust model_input so that you only do + # prefilling on those tokens that are missing KV caches. + logger.debug( + "[rank%d]: Failed to receive all KVs and hidden " + "states, redo model forwarding.", torch.distributed.get_rank()) + hidden_or_intermediate_states = None + + else: + logger.debug( + "[rank%d]: Successfully received all KVs and hidden " + "states, skip model forwarding.", torch.distributed.get_rank()) + hidden_or_intermediate_states = torch.cat( + hidden_or_intermediate_states_for_one_req, dim=0) + + return hidden_or_intermediate_states, bypass_model_exec, model_input + + def close(self): + self.data_pipe.close() + # self.data_plane.close() + + @staticmethod + def parse_request_id(request_id: str) -> Tuple[str, int]: + # Regular expression to match the string hostname and integer decode_kv_rank + pattern = r"___decode_hostname_(.*)___decode_kv_rank_(\d+)" + + # Use re.search to find the pattern in the request_id + match = re.search(pattern, request_id) + if match: + # Extract the ranks + decode_hostname = match.group(1) + decode_rank = int(match.group(2)) + + return decode_hostname, decode_rank + raise ValueError(f"Request id {request_id} does not contain hostname and decode_kv_rank") + + def _send(self, hostname: str, port: int, request_id: str, keys: torch.Tensor, values: torch.Tensor, hidden: torch.Tensor): + remote_address = f"{hostname}:{port}" + self.data_plane.send_tensor(keys, f"{request_id}_keys", remote_address) + self.data_plane.send_tensor(values, f"{request_id}_values", remote_address) + self.data_plane.send_tensor(hidden, f"{request_id}_hidden", remote_address) + + def _recv(self, request_id: str) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + keys = self.data_plane.recv_tensor(f"{request_id}_keys") + values = self.data_plane.recv_tensor(f"{request_id}_values") + hidden = self.data_plane.recv_tensor(f"{request_id}_hidden") + return keys, values, hidden + + def _get_kv_group_rank(self, kv_rank: int, rank: int, config: KVTransferConfig) -> int: + if 
kv_rank < config.kv_producers_parallel_size: + return kv_rank + + kv_consumer_rank = kv_rank - config.kv_producers_parallel_size + return config.kv_producers_parallel_size + kv_consumer_rank * config.tensor_parallel_multiplier + rank % config.tensor_parallel_multiplier + + + def _get_global_kv_rank(self, kv_rank: int, rank: int, config: KVTransferConfig) -> int: + if kv_rank <= config.kv_producers_parallel_size: + return kv_rank * config.kv_producers_tensor_parallel_size + rank + + kv_consumer_rank = kv_rank - config.kv_producers_parallel_size + return config.kv_producers_parallel_size * config.kv_producers_tensor_parallel_size + kv_consumer_rank * config.kv_consumers_tensor_parallel_size + rank + + + def _get_data_plane_port(self, global_kv_rank: int) -> int: + return self.config.kv_port + self.config.kv_producers_tensor_parallel_size + 1 + global_kv_rank + + def _broadcast_and_enhance_kv_config(self, rank: int, config: VllmConfig, world_group): + if rank == 0: + config_group = StatelessProcessGroup.create( + host=self.config.kv_ip, + port=self.config.kv_port, + rank=self.config.kv_rank, + world_size=self.config.kv_parallel_size, + ) + parallel_configs = config_group.all_gather_obj({ + "kv_role": self.config.kv_role, + "tensor_parallel_size": config.parallel_config.tensor_parallel_size, + "pipeline_parallel_size": config.parallel_config.pipeline_parallel_size, + }) + logger.debug("parallel_configs: %s", parallel_configs) + kv_config_enhanced = { + "kv_producers_tensor_parallel_size": None, + "kv_consumers_tensor_parallel_size": None, + "kv_producers_pipeline_parallel_size": None, + "kv_consumers_pipeline_parallel_size": None, + "kv_producers_parallel_size": 0, + } + for parallel_config in parallel_configs: + kv_role = parallel_config["kv_role"] + assert parallel_config["pipeline_parallel_size"] == 1, f"Only pipeline parallel size 1 is supported for kv transfer instances" + + if kv_role == "kv_producer": + kv_config_enhanced["kv_producers_parallel_size"] += 1 + if kv_config_enhanced[f"{kv_role}s_tensor_parallel_size"] is None: + kv_config_enhanced[f"{kv_role}s_tensor_parallel_size"] = parallel_config["tensor_parallel_size"] + kv_config_enhanced[f"{kv_role}s_pipeline_parallel_size"] = parallel_config["pipeline_parallel_size"] + else: + assert kv_config_enhanced[f"{kv_role}s_tensor_parallel_size"] == parallel_config["tensor_parallel_size"], f"All kv {kv_role}s should have the same tensor parallel size" + assert kv_config_enhanced[f"{kv_role}s_pipeline_parallel_size"] == parallel_config["pipeline_parallel_size"], f"All kv {kv_role}s should have the same pipeline parallel size" + world_group.broadcast_object(kv_config_enhanced) + else: + kv_config_enhanced = world_group.broadcast_object() + logger.info("kv_config_enhanced: %s", kv_config_enhanced) + + self.config.kv_producers_tensor_parallel_size = kv_config_enhanced["kv_producers_tensor_parallel_size"] + self.config.kv_consumers_tensor_parallel_size = kv_config_enhanced["kv_consumers_tensor_parallel_size"] + self.config.kv_producers_pipeline_parallel_size = kv_config_enhanced["kv_producers_pipeline_parallel_size"] + self.config.kv_consumers_pipeline_parallel_size = kv_config_enhanced["kv_consumers_pipeline_parallel_size"] + self.config.kv_producers_parallel_size = kv_config_enhanced["kv_producers_parallel_size"] diff --git a/vllm/distributed/kv_transfer/kv_pipe/dynamo_nccl_pipe.py b/vllm/distributed/kv_transfer/kv_pipe/dynamo_nccl_pipe.py new file mode 100644 index 0000000000000..3ee0fa78f4aac --- /dev/null +++ 
b/vllm/distributed/kv_transfer/kv_pipe/dynamo_nccl_pipe.py @@ -0,0 +1,124 @@ +import logging +import threading +import typing +import zmq +import socket +import time +import torch + +from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import PyNcclPipe + + +logger = logging.getLogger(__name__) + + +class DynamoNcclDataPlane: + def __init__( + self, + data_pipe: PyNcclPipe, + hostname: str = "", + port: int = 0, + ) -> None: + + self.data_pipe = data_pipe + if not hostname: + hostname = socket.gethostname() + if port == 0: + raise ValueError("Port cannot be 0") + self._hostname = hostname + self._port = port + self.store = {} + self.context = zmq.Context() + self.rep_socket = self.context.socket(zmq.REP) + logger.info(f"Rank {self.rank} binding to {self._hostname}:{self._port}") + self.rep_socket.bind(f"tcp://{self._hostname}:{self._port}") + self._listener_thread = threading.Thread(target=self.listen_for_requests, daemon=True) + self._listener_thread.start() + self.req_sockets = {} + logger.info(f"Rank {self.rank} connected to the server") + + @property + def rank(self): + return self.data_pipe.kv_group_rank + + def send_tensor( + self, + tensor: torch.Tensor, + tensor_id: str, + remote_address: typing.Optional[str] = None, + ): + logger.debug(f"Rank {self.rank} sending tensor {tensor_id} to {remote_address}") + return self._send_tensor(tensor, tensor_id, remote_address) + + def recv_tensor( + self, + tensor_id: str, + remote_address: typing.Optional[str] = None, + ) -> torch.Tensor: + ret = self._recv_tensor(tensor_id, remote_address) + return ret + + def _send_tensor( + self, + tensor: torch.Tensor, + tensor_id: str, + remote_address: typing.Optional[str] = None, + ): + logger.debug(f"Rank {self.rank} storing tensor with id {tensor_id} of shape {tensor.shape} and dtype {tensor.dtype}") + if remote_address is None: + self.store[tensor_id] = tensor + else: + # tensor_shape = "_".join(str(dim) for dim in tensor.shape) + # tensor_dtype = str(tensor.dtype) + if remote_address not in self.req_sockets: + self.req_sockets[remote_address] = self.context.socket(zmq.REQ) + self.req_sockets[remote_address].connect(f"tcp://{remote_address}") + + req_socket = self.req_sockets[remote_address] + # req_socket.connect(f"tcp://{remote_address}") + req_socket.send_string(f"PUT {self.rank} {tensor_id}") + dst_rank = req_socket.recv_string() + logger.debug(f"Rank {self.rank} sending tensor {tensor_id} to rank {dst_rank}") + self.data_pipe.send_tensor(tensor, int(dst_rank)) + + def _recv_tensor( + self, + tensor_id: str, + remote_address: typing.Optional[str] = None, + ) -> torch.Tensor: + logger.debug(f"Rank {self.rank} receiving tensor") + if remote_address is not None: + raise NotImplementedError("Getting tensor from remote rank not implemented") + if tensor_id in self.store: + logger.debug(f"Popping tensor {tensor_id} from store") + future = self.store.pop(tensor_id) + tensor = future.result() # TODO ptarasiewicz we should run other request instead of wait + logger.debug(f"Rank {self.rank} received tensor") + return tensor + + logger.debug(f"Rank {self.rank} waiting for tensor {tensor_id}") + time.sleep(0.001) + return self._recv_tensor(tensor_id, remote_address) + # raise NotImplementedError("Tensor not found in store") + + def _receive_tensor( + self, + tensor_id: str, + rank: int, + ): + future = self.data_pipe.recv_tensor(rank) + logger.debug(f"Rank {self.rank} storing tensor {tensor_id} in store") + self.store[tensor_id] = future + + def listen_for_requests(self): + while True: + cmd, rank, 
tensor_id = self.rep_socket.recv_string().split() + logger.debug(f"Rank {self.rank} received request for tensor {tensor_id}") + self.rep_socket.send_string(f"{self.rank}") + if cmd == "GET": + raise NotImplementedError("Getting tensor from remote rank not implemented") + elif cmd == "PUT": + rank = int(rank) + # shape = [int(dim) for dim in shape.split("_")] + # dtype = getattr(torch, dtype) + self._receive_tensor(tensor_id, rank) diff --git a/vllm/remote_prefill.py b/vllm/remote_prefill.py new file mode 100644 index 0000000000000..3f9711ef0e605 --- /dev/null +++ b/vllm/remote_prefill.py @@ -0,0 +1,67 @@ +from dataclasses import dataclass +from typing import Callable, Optional, List +from enum import Enum + +import msgspec + +from vllm.sampling_params import SamplingParams + + +class RemotePrefillRequest( + msgspec.Struct, + omit_defaults=True, # type: ignore[call-arg] + # required for @cached_property. + dict=True): + """The request data of one remote prefill output of a request. + + Args: + engine_id: The unique ID of the engine. + request_id: The unique ID of the request. + prompt_token_ids: The token IDs of the prompt. + sampling_params: The sampling parameters. + block_ids: The block IDs of the request. + computed_block_ids: The computed block IDs of the request. + """ + engine_id: str + request_id: str + prompt_token_ids: List[int] + sampling_params: SamplingParams + block_ids: List[int] + computed_block_ids: List[int] + + +class MemoryOpType(str, Enum): + WRITE = "WRITE" + READ = "READ" + + +class MemoryTransferRequest( + msgspec.Struct, + array_like=True, # type: ignore[call-arg] + omit_defaults=True): # type: ignore[call-arg] + """The request data of one memory transfer output of a request. + + Args: + request_id: The unique ID of the request. + """ + request_id: str + local_block_ids: List[int] + staging_block_ids: List[int] + remote_block_ids: List[int] + remote_engine_id: str + notify_msg: str + op_type: MemoryOpType + + +RemotePrefillRequestCallback = Callable[[RemotePrefillRequest], None] + + +@dataclass +class RemotePrefillParams: + """Remote prefill parameters for text generation.""" + is_remote_prefill: bool = False + is_remote_decode: bool = False + decode_block_ids: Optional[List[int]] = None + decode_computed_block_ids: Optional[List[int]] = None + decode_engine_id: Optional[str] = None + remote_prefill_request_callback: Optional[RemotePrefillRequestCallback] = None \ No newline at end of file
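
A minimal, hypothetical sketch of how the new remote-prefill types could be wired together (the queue, hostname, and engine id below are placeholders, not part of this patch; the request id format matches what DynamoConnector.parse_request_id expects):

    from vllm.remote_prefill import RemotePrefillParams, RemotePrefillRequest
    from vllm.sampling_params import SamplingParams

    prefill_queue = []  # placeholder transport; a real deployment would forward this to the prefill engine

    def forward_to_prefill(req: RemotePrefillRequest) -> None:
        # In practice the request would be serialized (e.g. with msgspec) and sent
        # to the prefill worker; here it is only enqueued locally.
        prefill_queue.append(req)

    # The request id carries the decode hostname and kv rank so the prefill side
    # can parse them back out with DynamoConnector.parse_request_id.
    request_id = "req-123___decode_hostname_decode-0___decode_kv_rank_0"

    params = RemotePrefillParams(
        is_remote_prefill=True,
        remote_prefill_request_callback=forward_to_prefill,
    )

    # The engine (not shown in this patch) would invoke the callback with the
    # block layout it allocated for the request:
    params.remote_prefill_request_callback(
        RemotePrefillRequest(
            engine_id="decode-engine-0",
            request_id=request_id,
            prompt_token_ids=[1, 2, 3, 4],
            sampling_params=SamplingParams(max_tokens=16),
            block_ids=[0, 1],
            computed_block_ids=[],
        )
    )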