diff --git a/vllm/config.py b/vllm/config.py index 9ba4975761245..a4df00193a980 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2620,6 +2620,9 @@ class KVTransferConfig(BaseModel): # The KV connector for vLLM to transmit KV caches between vLLM instances. kv_connector: Optional[str] = None + # Whether to use NIXL prepped xfer for KV cache transfer. + use_prepped_xfer: bool = True + # The device used by kv connector to buffer the KV cache. # Currently only support 'cuda'. kv_buffer_device: Optional[str] = "cuda" @@ -2629,7 +2632,7 @@ class KVTransferConfig(BaseModel): kv_buffer_size: float = 1e9 # Whether this vLLM instance produces, consumes KV cache, or both. Choices - # are 'kv_producer', 'kv_consumer', and 'both'. + # are 'kv_producer', 'kv_consumer', and 'kv_both'. kv_role: Optional[str] = None # The rank of this vLLM instance in the KV cache transfer. Typical value: @@ -2647,6 +2650,14 @@ class KVTransferConfig(BaseModel): # The KV connector port, used to build distributed connection kv_port: int = 14579 + + # This does not need to be set by the user. It is set by the connector. + kv_producers_parallel_size: Optional[int] = None + kv_producers_tensor_parallel_size: Optional[int] = None + kv_producers_pipeline_parallel_size: Optional[int] = None + kv_consumers_tensor_parallel_size: Optional[int] = None + kv_consumers_pipeline_parallel_size: Optional[int] = None + def compute_hash(self) -> str: """ WARNING: Whenever a new field is added to this config, @@ -2680,11 +2691,16 @@ class KVTransferConfig(BaseModel): f"Supported roles are `kv_producer`, `kv_consumer`, " f"and `kv_both`") - if self.kv_connector is not None and self.kv_role is None: + if self.kv_connector is not None and self.kv_connector != "DynamoNixlConnector" and self.kv_role is None: raise ValueError("Please specify kv_disagg_role when kv_connector " "is set, supported roles are `kv_producer`, " "`kv_consumer`, and `kv_both`") + if self.use_prepped_xfer is False: + logger.warning("`use_prepped_xfer` parameter is deprecated. All transfers will be done using prepped xfer.") + self.use_prepped_xfer = True + + @property def is_kv_transfer_instance(self) -> bool: return self.kv_connector is not None and \ @@ -2694,6 +2710,8 @@ class KVTransferConfig(BaseModel): def need_kv_parallel_group(self) -> bool: # for those database-based connector, vLLM does not need to create # parallel group, and in that case the kv parallel size will be 1. 
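Worked example of the parallel-size arithmetic the new KVTransferConfig properties encode (the sizes below are illustrative, not taken from the patch):

# Illustrative values only; mirrors the property arithmetic above.
kv_parallel_size = 3                    # 2 producer instances + 1 consumer instance
kv_producers_parallel_size = 2
kv_producers_tensor_parallel_size = 1   # each producer runs TP=1
kv_consumers_tensor_parallel_size = 4   # each consumer runs TP=4

tensor_parallel_multiplier = (kv_consumers_tensor_parallel_size
                              // kv_producers_tensor_parallel_size)           # 4
kv_consumers_parallel_size = kv_parallel_size - kv_producers_parallel_size    # 1
kv_world_size = (kv_producers_parallel_size
                 + kv_consumers_parallel_size * tensor_parallel_multiplier)   # 2 + 4 = 6
assert (tensor_parallel_multiplier, kv_consumers_parallel_size, kv_world_size) == (4, 1, 6)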
+ if self.kv_connector == "DynamoNixlConnector": + return False return self.kv_connector is not None and self.kv_parallel_size > 1 @property @@ -2706,6 +2724,18 @@ class KVTransferConfig(BaseModel): return self.kv_connector is not None and \ self.kv_role in ["kv_consumer", "kv_both"] + @property + def tensor_parallel_multiplier(self) -> int: + return self.kv_consumers_tensor_parallel_size // self.kv_producers_tensor_parallel_size + + @property + def kv_consumers_parallel_size(self) -> int: + return self.kv_parallel_size - self.kv_producers_parallel_size + + @property + def kv_world_size(self) -> int: + return self.kv_producers_parallel_size + self.kv_consumers_parallel_size * self.tensor_parallel_multiplier + class CompilationLevel: # constants for the levels of the compilation process diff --git a/vllm/core/block/cpu_gpu_block_allocator.py b/vllm/core/block/cpu_gpu_block_allocator.py index 359b5b263f689..d52ee050cb3b4 100644 --- a/vllm/core/block/cpu_gpu_block_allocator.py +++ b/vllm/core/block/cpu_gpu_block_allocator.py @@ -6,6 +6,7 @@ from vllm.core.block.interfaces import (Block, BlockAllocator, BlockId, DeviceAwareBlockAllocator) from vllm.core.block.naive_block import NaiveBlock, NaiveBlockAllocator from vllm.core.block.prefix_caching_block import PrefixCachingBlockAllocator +from vllm.core.event_manager import KVCacheEventManager from vllm.platforms import current_platform from vllm.utils import Device @@ -28,6 +29,7 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator): num_gpu_blocks: int, num_cpu_blocks: int, block_size: int, + event_manager: Optional[KVCacheEventManager] = None, ) -> DeviceAwareBlockAllocator: """Creates a CpuGpuBlockAllocator instance with the specified configuration. @@ -64,6 +66,7 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator): cpu_block_ids = block_ids[num_gpu_blocks:] if allocator_type == "naive": + assert event_manager is None, "Event API not supported with naive allocator." 
gpu_allocator: BlockAllocator = NaiveBlockAllocator( create_block=NaiveBlock, # type: ignore num_blocks=num_gpu_blocks, block_size=block_size, block_ids=gpu_block_ids, @@ -82,12 +85,14 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator): num_blocks=num_gpu_blocks, block_size=block_size, block_ids=gpu_block_ids, + event_manager=event_manager, ) cpu_allocator = PrefixCachingBlockAllocator( num_blocks=num_cpu_blocks, block_size=block_size, block_ids=cpu_block_ids, + event_manager=event_manager, ) else: raise ValueError(f"Unknown allocator type {allocator_type=}") @@ -95,10 +100,12 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator): return CpuGpuBlockAllocator( cpu_block_allocator=cpu_allocator, gpu_block_allocator=gpu_allocator, + event_manager=event_manager, ) def __init__(self, cpu_block_allocator: BlockAllocator, - gpu_block_allocator: BlockAllocator): + gpu_block_allocator: BlockAllocator, + event_manager: Optional[KVCacheEventManager] = None,): assert not ( cpu_block_allocator.all_block_ids & gpu_block_allocator.all_block_ids @@ -108,6 +115,7 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator): Device.CPU: cpu_block_allocator, Device.GPU: gpu_block_allocator, } + self.event_manager = event_manager self._swap_mapping: Dict[int, int] = {} self._null_block: Optional[Block] = None diff --git a/vllm/core/block/naive_block.py b/vllm/core/block/naive_block.py index c388366b825f2..31ed7aa44ada0 100644 --- a/vllm/core/block/naive_block.py +++ b/vllm/core/block/naive_block.py @@ -2,7 +2,7 @@ from collections import deque from typing import Deque, FrozenSet, Iterable, List, Optional, Tuple, Union - +import heapq from vllm.core.block.common import (BlockPool, CopyOnWriteTracker, RefCounter, get_all_blocks_recursively) from vllm.core.block.interfaces import Block, BlockAllocator, BlockId, Device @@ -38,7 +38,7 @@ class NaiveBlockAllocator(BlockAllocator): if block_ids is None: block_ids = range(num_blocks) - self._free_block_indices: Deque[BlockId] = deque(block_ids) + self._free_block_indices: List[BlockId] = list(block_ids) self._all_block_indices = frozenset(block_ids) assert len(self._all_block_indices) == num_blocks @@ -134,7 +134,8 @@ class NaiveBlockAllocator(BlockAllocator): if not self._free_block_indices: raise BlockAllocator.NoFreeBlocksError() - block_id = self._free_block_indices.popleft() + block_id = heapq.heappop(self._free_block_indices) + # TODO: figure out why block_id is sometimes None self._refcounter.incr(block_id) return block_id @@ -148,7 +149,7 @@ class NaiveBlockAllocator(BlockAllocator): refcount = self._refcounter.decr(block_id) if refcount == 0: - self._free_block_indices.appendleft(block_id) + heapq.heappush(self._free_block_indices, block_id) def free(self, block: Block, keep_block_object: bool = False) -> None: # Release the physical block id diff --git a/vllm/core/block/prefix_caching_block.py b/vllm/core/block/prefix_caching_block.py index 1ca9e49dac371..cd780f698859a 100644 --- a/vllm/core/block/prefix_caching_block.py +++ b/vllm/core/block/prefix_caching_block.py @@ -4,7 +4,7 @@ import sys from bisect import bisect_left from os.path import commonprefix from typing import (Callable, Dict, FrozenSet, Iterable, List, Optional, Set, - Tuple) + Tuple, TYPE_CHECKING) from vllm.core.block.common import (CacheMetricData, CopyOnWriteTracker, get_all_blocks_recursively) @@ -23,6 +23,9 @@ PrefixHash = int # then we know this block hasn't been accessed yet.
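Self-contained sketch of the free-list change above: heappop always hands out the lowest free block id, so allocation order is deterministic regardless of free order, whereas the previous appendleft/popleft pair returned the most recently freed block:

import heapq

free = list(range(4))        # a sorted list is already a valid min-heap
for _ in range(3):
    heapq.heappop(free)      # allocate blocks 0, 1, 2; free == [3]
heapq.heappush(free, 0)      # free block 0 first...
heapq.heappush(free, 2)      # ...then block 2
assert heapq.heappop(free) == 0   # heap yields the lowest id;
                                  # deque.appendleft/popleft would yield 2 here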
_DEFAULT_LAST_ACCESSED_TIME = -1 +if TYPE_CHECKING: + from vllm.core.event_manager import KVCacheEventManager + logger = init_logger(__name__) @@ -80,6 +83,7 @@ class PrefixCachingBlockAllocator(BlockAllocator): block_size: int, block_ids: Optional[Iterable[int]] = None, eviction_policy: EvictionPolicy = EvictionPolicy.LRU, + event_manager: Optional["KVCacheEventManager"] = None, ): if block_ids is None: block_ids = range(num_blocks) @@ -131,6 +135,9 @@ class PrefixCachingBlockAllocator(BlockAllocator): self.metric_data = CacheMetricData() + self.event_manager = event_manager + + # Implements Block.Factory. def _create_block( self, prev_block: Optional[Block], @@ -337,6 +344,9 @@ class PrefixCachingBlockAllocator(BlockAllocator): assert self._refcounter.get(_block_id) == 0 assert _block_id == block_id + if self.event_manager: + self.event_manager.enqueue_removed_event(content_hash_to_evict) + self._cached_blocks.pop(content_hash_to_evict) self._refcounter.incr(block_id) @@ -513,6 +523,10 @@ class PrefixCachingBlockAllocator(BlockAllocator): # Mark this block as touched so that it can be marked as # computed after the entire batch of sequences are scheduled. self._touched_blocks.add(block.block_id) + + if self.event_manager: + self.event_manager.enqueue_stored_event(block.prev_block, block) + return block.block_id # Reuse the cached content hash @@ -579,9 +593,11 @@ class PrefixCachingBlockAllocator(BlockAllocator): def mark_blocks_as_computed(self, block_ids: List[int]) -> None: # Mark all touched blocks as computed. - for block_id in self._touched_blocks: - self._block_tracker[block_id].computed = True - self._touched_blocks.clear() + for block_id in block_ids: + if block_id in self._touched_blocks: + logger.debug("Mark block as computed: %s", block_id) + self._block_tracker[block_id].computed = True + self._touched_blocks.remove(block_id) def _track_block_id(self, block_id: Optional[BlockId], computed: bool) -> None: diff --git a/vllm/core/block_manager.py b/vllm/core/block_manager.py index c5b3b04f37ca3..21fe0fc884425 100644 --- a/vllm/core/block_manager.py +++ b/vllm/core/block_manager.py @@ -10,7 +10,10 @@ from vllm.core.block.interfaces import Block from vllm.core.block.prefix_caching_block import (ComputedBlocksTracker, LastAccessBlocksTracker) from vllm.core.block.utils import check_no_caching_or_swa_for_blockmgr_encdec +from vllm.core.event_manager import KVCacheEventManager from vllm.core.interfaces import AllocStatus, BlockSpaceManager +from vllm.envs import (VLLM_KV_CAPI_PATH, VLLM_KV_COMPONENT, VLLM_KV_NAMESPACE, + VLLM_WORKER_ID) from vllm.sequence import Sequence, SequenceGroup, SequenceStatus from vllm.utils import Device @@ -60,6 +63,7 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager): def __init__( self, + model_name: str, block_size: int, num_gpu_blocks: int, num_cpu_blocks: int, @@ -91,11 +95,29 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager): self.watermark_blocks = int(watermark * num_gpu_blocks) + kv_event_manager_params = [ + VLLM_WORKER_ID, VLLM_KV_CAPI_PATH, VLLM_KV_NAMESPACE, + VLLM_KV_COMPONENT + ] + set_kv_event_manager_params = len( + [param for param in kv_event_manager_params if param is not None]) + + if set_kv_event_manager_params == len(kv_event_manager_params): + self.event_manager = KVCacheEventManager( + namespace=VLLM_KV_NAMESPACE, + component=VLLM_KV_COMPONENT, + worker_id=VLLM_WORKER_ID, + lib_path=VLLM_KV_CAPI_PATH, + kv_block_size=block_size) + else: + self.event_manager = None + self.block_allocator = CpuGpuBlockAllocator.create( 
allocator_type="prefix_caching" if enable_caching else "naive", num_gpu_blocks=num_gpu_blocks, num_cpu_blocks=num_cpu_blocks, block_size=block_size, + event_manager=self.event_manager, ) self.block_tables: Dict[SeqId, BlockTable] = {} @@ -108,7 +130,8 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager): def can_allocate(self, seq_group: SequenceGroup, - num_lookahead_slots: int = 0) -> AllocStatus: + num_lookahead_slots: int = 0, + is_remote_decode: bool = False) -> AllocStatus: # FIXME(woosuk): Here we assume that all sequences in the group share # the same prompt. This may not be true for preempted sequences. @@ -121,6 +144,10 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager): num_lookahead_slots=num_lookahead_slots, ) + # if remote decode, we need to allocate twice as many blocks for staging + if is_remote_decode: + num_required_blocks *= 2 + if seq_group.is_encoder_decoder(): encoder_seq = seq_group.get_encoder_seq() assert encoder_seq is not None diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index f507847ad82cf..170a359f602fb 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -4,22 +4,22 @@ import enum import os import random import time +import copy from collections import deque from dataclasses import dataclass, field from typing import Callable, Deque, Dict, Iterable, List, Optional from typing import Sequence as GenericSequence -from typing import Set, Tuple, Union +from typing import Set, Tuple, Union, Any -from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig +from vllm.config import ModelConfig, CacheConfig, LoRAConfig, SchedulerConfig from vllm.core.interfaces import AllocStatus, BlockSpaceManager from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sequence import (Sequence, SequenceData, SequenceGroup, SequenceGroupMetadata, SequenceGroupMetadataDelta, - SequenceStatus) + SequenceStatus, SequenceStage) from vllm.utils import Device, PyObjectCache - logger = init_logger(__name__) # Test-only. If configured, decode is preempted with @@ -285,6 +285,7 @@ class SchedulerPrefillOutputs: # Ignored sequence groups. ignored_seq_groups: List[SequenceGroup] num_lookahead_slots: int + num_remote_prefill_groups: int @classmethod def create_empty(cls) -> "SchedulerPrefillOutputs": @@ -292,6 +293,7 @@ class SchedulerPrefillOutputs: seq_groups=[], ignored_seq_groups=[], num_lookahead_slots=0, + num_remote_prefill_groups=0, ) @@ -325,12 +327,14 @@ class Scheduler: def __init__( self, + model_config: ModelConfig, scheduler_config: SchedulerConfig, cache_config: CacheConfig, lora_config: Optional[LoRAConfig], pipeline_parallel_size: int = 1, output_proc_callback: Optional[Callable] = None, ) -> None: + self.model_config = model_config self.scheduler_config = scheduler_config self.cache_config = cache_config # Note for LoRA scheduling: the current policy is extremely @@ -356,6 +360,7 @@ class Scheduler: # Create the block space manager. self.block_manager = BlockSpaceManagerImpl( + model_name=self.model_config.served_model_name, block_size=self.cache_config.block_size, num_gpu_blocks=num_gpu_blocks, num_cpu_blocks=num_cpu_blocks, @@ -371,6 +376,16 @@ class Scheduler: # Sequence groups in the SWAPPED state. # Contain decode requests that are swapped out. self.swapped: Deque[SequenceGroup] = deque() + + # Sequence groups in the REMOTE_PREFILLING state. + # Contain requests that are being prefilled by a remote worker. 
+ self.remote_prefilling: Deque[SequenceGroup] = deque() + # Contain requests prefilled locally whose KV cache is being sent to a remote decode worker. + self.prefill_sending: Deque[SequenceGroup] = deque() + + self._remote_prefill_outputs: Dict[str, int] = {} + + + # Sequence groups finished requests ids since last step iteration. # It lets the model know that any state associated with these requests # can and must be released after the current step. @@ -501,7 +516,7 @@ class Scheduler: def has_unfinished_seqs(self) -> bool: return len(self.waiting) != 0 or len(self.running) != 0 or len( - self.swapped) != 0 + self.swapped) != 0 or len(self.remote_prefilling) != 0 or len(self.prefill_sending) != 0 def get_prefix_cache_hit_rate(self, device: Device) -> float: return self.block_manager.get_prefix_cache_hit_rate(device) @@ -523,6 +538,8 @@ class Scheduler: budget: SchedulingBudget, curr_loras: Optional[Set[int]], enable_chunking: bool = False, + finished_prefills: Optional[Set[str]] = None, + finished_transfers: Optional[Set[str]] = None ) -> SchedulerRunningOutputs: """Schedule sequence groups that are running. @@ -537,6 +554,8 @@ class Scheduler: chunked number of tokens are scheduled if `budget.num_batched_tokens` has not enough capacity to schedule all tokens. + finished_prefills: Request ids whose remote prefill has finished. + finished_transfers: Request ids whose KV cache transfer has finished. Returns: SchedulerRunningOutputs. @@ -566,6 +585,38 @@ class Scheduler: preempted: List[SequenceGroup] = ret.preempted swapped_out: List[SequenceGroup] = ret.swapped_out + remote_prefilling_queue = self.remote_prefilling + leftover_remote_prefilling_sequences: Deque[SequenceGroup] = deque() + while remote_prefilling_queue: + seq_group = remote_prefilling_queue.popleft() + if seq_group.request_id not in finished_prefills: + leftover_remote_prefilling_sequences.append(seq_group) + continue + + else: + finished_prefills.remove(seq_group.request_id) + assert len(seq_group.seqs) == 1 + seq = seq_group.seqs[0] + # Remote prefill computed all but the last prompt token; the first decode step runs locally on that last token + seq_group.update_num_computed_tokens(seq.get_len() - 1) + seq.status = SequenceStatus.RUNNING + seq.data._stage = SequenceStage.DECODE + self.running.appendleft(seq_group) + remote_prefilling_queue.extendleft(leftover_remote_prefilling_sequences) + + remote_transfers_queue = self.prefill_sending + leftover_remote_transfers_sequences: Deque[SequenceGroup] = deque() + while remote_transfers_queue: + seq_group = remote_transfers_queue.popleft() + if seq_group.request_id not in finished_transfers: + leftover_remote_transfers_sequences.append(seq_group) + else: + finished_transfers.remove(seq_group.request_id) + assert len(seq_group.seqs) == 1 + seq = seq_group.seqs[0] + self.free_seq(seq) + remote_transfers_queue.extendleft(leftover_remote_transfers_sequences) + running_queue = self.running assert len(self._async_stopped) == 0 while running_queue: @@ -925,6 +976,7 @@ class Scheduler: seq_groups: List[ScheduledSequenceGroup] = [] waiting_queue = self.waiting + num_remote_prefill_groups = 0 leftover_waiting_sequences: Deque[SequenceGroup] = deque() while self._passed_delay(time.time()) and waiting_queue: @@ -961,8 +1013,10 @@ class Scheduler: True, enable_chunking) # If the sequence group cannot be allocated, stop.
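Self-contained sketch (plain strings standing in for SequenceGroups) of the drain pattern the two loops above implement: a request leaves the queue only once its id appears in the finished set, consumed ids are removed from the set in place, and everything else is pushed back for the next scheduling step:

from collections import deque

remote_prefilling = deque(["req-a", "req-b", "req-c"])
finished_prefills = {"req-b"}

leftover, running = deque(), []
while remote_prefilling:
    req = remote_prefilling.popleft()
    if req not in finished_prefills:
        leftover.append(req)
    else:
        finished_prefills.remove(req)
        running.append(req)             # would transition to RUNNING / DECODE
remote_prefilling.extendleft(leftover)  # note: extendleft reverses the order

assert running == ["req-b"] and not finished_prefills
assert list(remote_prefilling) == ["req-c", "req-a"]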
+ is_remote_decode = seq_group.remote_prefill_params is not None and seq_group.remote_prefill_params.is_remote_decode can_allocate = self.block_manager.can_allocate( - seq_group, num_lookahead_slots=num_lookahead_slots) + seq_group, num_lookahead_slots=num_lookahead_slots, + is_remote_decode=is_remote_decode) if can_allocate == AllocStatus.LATER: break elif can_allocate == AllocStatus.NEVER: @@ -1008,7 +1062,18 @@ class Scheduler: if curr_loras is not None and lora_int_id > 0: curr_loras.add(lora_int_id) waiting_queue.popleft() - self._allocate_and_set_running(seq_group) + + seq_group_copy = copy.deepcopy(seq_group) + seq_group_copy.seqs[0].seq_id = seq_group.seqs[0].seq_id + 1 + + logger.debug("Allocating and setting running or remote prefill for seq_group %s", seq_group.request_id) + logger.debug("Seq id: %s", seq_group.seqs[0].seq_id) + is_remote_prefill = self._allocate_and_set_running_or_remote_prefill(seq_group) + num_remote_prefill_groups += is_remote_prefill + if is_remote_decode: + logger.debug("Seq id: %s", seq_group_copy.seqs[0].seq_id) + self._allocate_and_set_running_or_remote_prefill(seq_group_copy) + self.prefill_sending.append(seq_group_copy) if enable_chunking and self.scheduler_config.is_multi_step: blocks_to_copy: List[Tuple[int, int]] = [] @@ -1046,9 +1111,11 @@ class Scheduler: seq_groups=seq_groups, ignored_seq_groups=ignored_seq_groups, num_lookahead_slots=self._get_num_lookahead_slots( - is_prefill=True, enable_chunking=enable_chunking)) + is_prefill=True, enable_chunking=enable_chunking), + num_remote_prefill_groups=num_remote_prefill_groups + ) - def _schedule_default(self) -> SchedulerOutputs: + def _schedule_default(self, finished_prefills: Optional[Set[str]] = None, finished_transfers: Optional[Set[str]] = None) -> SchedulerOutputs: """Schedule queued requests. The current policy is designed to optimize the throughput. First, @@ -1066,9 +1133,13 @@ class Scheduler: for seq_group in self.running: budget.add_num_seqs(seq_group.request_id, seq_group.get_max_num_running_seqs()) - curr_loras = set( + for seq_group in self.remote_prefilling: + budget.add_num_seqs(seq_group.request_id, + seq_group.get_max_num_running_seqs()) + + curr_loras = (set( seq_group.lora_int_id for seq_group in self.running - if seq_group.lora_int_id > 0) if self.lora_enabled else None + if seq_group.lora_int_id > 0) if self.lora_enabled else None) prefills = SchedulerPrefillOutputs.create_empty() running_scheduled = SchedulerRunningOutputs.create_empty() @@ -1090,7 +1161,9 @@ class Scheduler: if len(prefills.seq_groups) == 0: running_scheduled = self._schedule_running(budget, curr_loras, - enable_chunking=False) + enable_chunking=False, + finished_prefills=finished_prefills, + finished_transfers=finished_transfers) # If any sequence group is preempted, do not swap in any sequence # group. because it means there's no slot for new running requests. @@ -1106,7 +1179,12 @@ class Scheduler: self.waiting.extendleft(running_scheduled.preempted) # Update new running requests. 
if len(prefills.seq_groups) > 0: - self.running.extend([s.seq_group for s in prefills.seq_groups]) + for s in prefills.seq_groups: + seq_group = s.seq_group + if seq_group.remote_prefill_params is not None and seq_group.remote_prefill_params.is_remote_prefill: + self.remote_prefilling.append(seq_group) + else: + self.running.append(seq_group) self.running.extend(running_scheduled.decode_seq_groups_list) @@ -1248,12 +1326,14 @@ class Scheduler: len(running_scheduled.swapped_out)), ) - def _schedule(self) -> SchedulerOutputs: + def _schedule(self, finished_prefills: Optional[Set[str]] = None, finished_transfers: Optional[Set[str]] = None) -> SchedulerOutputs: """Schedule queued requests.""" if self.scheduler_config.chunked_prefill_enabled: + if finished_prefills or finished_transfers: + raise ValueError("Chunked prefill does not support remote prefills") return self._schedule_chunked_prefill() else: - return self._schedule_default() + return self._schedule_default(finished_prefills, finished_transfers) def _can_append_slots(self, seq_group: SequenceGroup, enable_chunking: bool) -> bool: @@ -1287,14 +1367,16 @@ class Scheduler: return no_single_seq def schedule( - self + self, + finished_prefills: Optional[Set[str]] = None, + finished_transfers: Optional[Set[str]] = None ) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs, bool]: # Schedule sequence groups. # This function call changes the internal states of the scheduler # such as self.running, self.swapped, and self.waiting. - scheduler_start_time = time.perf_counter() - scheduler_outputs: SchedulerOutputs = self._schedule() + scheduler_start_time = time.perf_counter() + scheduler_outputs: SchedulerOutputs = self._schedule(finished_prefills, finished_transfers) now = time.time() if not self.cache_config.enable_prefix_caching: @@ -1333,7 +1415,8 @@ class Scheduler: encoder_seq_data = None cross_block_table = None - for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): + running_or_remote_prefilling_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING) + seq_group.get_seqs(status=SequenceStatus.REMOTE_PREFILLING) + for seq in running_or_remote_prefilling_seqs: seq_id = seq.seq_id seq_data[seq_id] = seq.data block_tables[seq_id] = self.block_manager.get_block_table(seq) @@ -1342,7 +1425,9 @@ class Scheduler: if self.cache_config.enable_prefix_caching: common_computed_block_nums = ( self.block_manager.get_common_computed_block_ids( - seq_group.get_seqs(status=SequenceStatus.RUNNING))) + running_or_remote_prefilling_seqs + ) + ) do_sample = True is_prompt = seq_group.is_prefill() @@ -1364,9 +1449,30 @@ class Scheduler: < seqs[0].data.get_len()): do_sample = False + is_remote_prefill = False + if is_first_prefill and seq_group.remote_prefill_params is not None and seq_group.remote_prefill_params.is_remote_prefill: + is_remote_prefill = True + logger.debug("Remote prefill, computed block nums: %s", common_computed_block_nums) + if is_first_prefill and seq_group.remote_prefill_params is not None and seq_group.remote_prefill_params.is_remote_decode: + block_tables[seq_group.seqs[0].seq_id + 1] = self.block_manager.block_tables[seq.seq_id + 1].physical_block_ids + + # Since we know that prefill is scheduled we can + # assume that the blocks computed on decode + # will be fetched by the time we run prefill + logger.debug("Computed decode blocks: %s", seq_group.remote_prefill_params.decode_computed_block_ids) + if seq_group.remote_prefill_params.decode_computed_block_ids: + computed_block_ids = 
set(seq_group.remote_prefill_params.decode_computed_block_ids) + prefill_block_ids = block_tables[seq_group.seqs[0].seq_id] + prefill_fetched_block_ids = [prefill_block_ids[i] for i, block_id in enumerate(seq_group.remote_prefill_params.decode_block_ids) if block_id in computed_block_ids and i < len(prefill_block_ids)] + + assert len(common_computed_block_nums) == 0, "common_computed_block_nums should be empty for remote prefill as it doesn't support prefix caching" + common_computed_block_nums = prefill_fetched_block_ids + + # It assumes the scheduled_seq_groups is ordered by # prefill < decoding. if is_first_prefill or not self.scheduler_config.send_delta_data: + logger.debug("Assigned blocks: %s", block_tables) seq_group_metadata = SequenceGroupMetadata( request_id=seq_group.request_id, is_prompt=is_prompt, @@ -1392,6 +1498,7 @@ class Scheduler: if scheduler_outputs.num_prefill_groups > 0 else None, mm_processor_kwargs=seq_group.mm_processor_kwargs, prompt_adapter_request=seq_group.prompt_adapter_request, + do_remote_prefill=is_remote_prefill, ) else: # When SPMD mode is enabled, we only send delta data except for @@ -1490,11 +1597,17 @@ class Scheduler: self._async_stopped.clear() - def _allocate_and_set_running(self, seq_group: SequenceGroup) -> None: + def _allocate_and_set_running_or_remote_prefill(self, seq_group: SequenceGroup) -> bool: self.block_manager.allocate(seq_group) + is_remote_prefill = False for seq in seq_group.get_seqs(status=SequenceStatus.WAITING): - seq.status = SequenceStatus.RUNNING - + if seq_group.remote_prefill_params is not None and seq_group.remote_prefill_params.is_remote_prefill: + seq.status = SequenceStatus.REMOTE_PREFILLING + is_remote_prefill = True + else: + seq.status = SequenceStatus.RUNNING + return is_remote_prefill + def _append_slots(self, seq_group: SequenceGroup, blocks_to_copy: List[Tuple[int, int]], diff --git a/vllm/distributed/kv_transfer/kv_connector/factory.py b/vllm/distributed/kv_transfer/kv_connector/factory.py index fe480533458b8..c82fda805e719 100644 --- a/vllm/distributed/kv_transfer/kv_connector/factory.py +++ b/vllm/distributed/kv_transfer/kv_connector/factory.py @@ -27,13 +27,13 @@ class KVConnectorFactory: @classmethod def create_connector(cls, rank: int, local_rank: int, - config: "VllmConfig") -> KVConnectorBase: + config: "VllmConfig", world_group) -> KVConnectorBase: connector_name = config.kv_transfer_config.kv_connector if connector_name not in cls._registry: raise ValueError(f"Unsupported connector type: {connector_name}") connector_cls = cls._registry[connector_name]() - return connector_cls(rank, local_rank, config) + return connector_cls(rank, local_rank, config, world_group) # Register various connectors here. @@ -48,3 +48,8 @@ KVConnectorFactory.register_connector( "MooncakeConnector", "vllm.distributed.kv_transfer.kv_connector.simple_connector", "SimpleConnector") + +KVConnectorFactory.register_connector( + "DynamoNcclConnector", + "vllm.distributed.kv_transfer.kv_connector.dynamo_connector", + "DynamoConnector") diff --git a/vllm/distributed/kv_transfer/kv_connector/simple_connector.py b/vllm/distributed/kv_transfer/kv_connector/simple_connector.py index 2033e9762ac0b..ddebb68e6751a 100644 --- a/vllm/distributed/kv_transfer/kv_connector/simple_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/simple_connector.py @@ -8,13 +8,15 @@ MooncakePipe. But the logic can be extended to support other pipe and lookup buffer.
""" +import re from typing import TYPE_CHECKING, List, Optional, Tuple, Union import torch from vllm import _custom_ops as ops -from vllm.config import VllmConfig +from vllm.config import VllmConfig, KVTransferConfig from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase +from vllm.distributed.utils import StatelessProcessGroup from vllm.distributed.kv_transfer.kv_lookup_buffer.simple_buffer import ( SimpleBuffer) from vllm.logger import init_logger @@ -33,6 +35,7 @@ class SimpleConnector(KVConnectorBase): rank: int, local_rank: int, config: VllmConfig, + world_group, ): self.config = config.kv_transfer_config @@ -71,20 +74,31 @@ class SimpleConnector(KVConnectorBase): self.producer_signal_pipe: Union[PyNcclPipe, MooncakePipe] self.consumer_signal_pipe: Union[PyNcclPipe, MooncakePipe] - # 2 pipes for every rank in the world - port_offset_base = 2 * rank + self._broadcast_and_enhance_kv_config(rank, config, world_group) + self.kv_group_rank = self._get_kv_group_rank(self.config.kv_rank, rank, self.config) + self.tp_size = config.parallel_config.tensor_parallel_size + + # 2 pipes for every rank in the world + if self.config.is_kv_producer: + port_offset_base = 2 * rank + 1 + else: + port_offset_base = 2 * (rank // self.config.tensor_parallel_multiplier) + 1 + + self.local_kv_rank = rank % self.config.tensor_parallel_multiplier # In disaggregated prefill, the prefill vLLM only uses send pipe # and the decode vLLM only uses recv pipe if self.config.is_kv_producer: if self.config.kv_connector == "PyNcclConnector": self.producer_data_pipe = PyNcclPipe( + kv_group_rank=self.kv_group_rank, local_rank=local_rank, config=self.config, port_offset=port_offset_base, ) self.producer_signal_pipe = PyNcclPipe( + kv_group_rank=self.kv_group_rank, local_rank=local_rank, config=self.config, port_offset=port_offset_base + 1, @@ -108,11 +122,13 @@ class SimpleConnector(KVConnectorBase): # its recv pipe to the send pipe of KV producder if self.config.kv_connector == "PyNcclConnector": self.consumer_data_pipe = PyNcclPipe( + kv_group_rank=self.kv_group_rank, local_rank=local_rank, config=self.config, port_offset=port_offset_base, ) self.consumer_signal_pipe = PyNcclPipe( + kv_group_rank=self.kv_group_rank, local_rank=local_rank, config=self.config, port_offset=port_offset_base + 1, @@ -131,21 +147,25 @@ class SimpleConnector(KVConnectorBase): self.config.kv_buffer_size, ) - def select(self, input_tokens: Optional[torch.Tensor], + def select(self, source_rank: int, input_tokens: Optional[torch.Tensor], roi: Optional[torch.Tensor]) -> List[Optional[torch.Tensor]]: + logger.info("Selecting KV caches and hidden states for source rank %d", source_rank) + assert self.consumer_buffer is not None, "Please initialize the "\ "consumer buffer before calling select." - return self.consumer_buffer.drop_select(input_tokens, roi) + return self.consumer_buffer.drop_select(source_rank, self.local_kv_rank, input_tokens, roi) - def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor, + def insert(self, kv_group_rank: int, target_rank: int, input_tokens: torch.Tensor, roi: torch.Tensor, key: torch.Tensor, value: torch.Tensor, hidden: torch.Tensor) -> None: + logger.info("Inserting KV caches and hidden states for kv_group_rank %d, target rank %d", kv_group_rank, target_rank) + assert self.producer_buffer is not None, "Please initialize the "\ "producer buffer before calling insert." 
- self.producer_buffer.insert(input_tokens, roi, key, value, hidden) + self.producer_buffer.insert(kv_group_rank, target_rank, input_tokens, roi, key, value, hidden) def send_kv_caches_and_hidden_states( self, @@ -161,12 +181,20 @@ class SimpleConnector(KVConnectorBase): slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten() start_layer = model_executable.model.start_layer end_layer = model_executable.model.end_layer + request_ids = list(model_input.request_ids_to_seq_ids.keys()) model_config = model_executable.model.config - num_heads = int(model_config.num_key_value_heads / self.tp_size) - hidden_size = model_config.hidden_size - num_attention_heads = model_config.num_attention_heads - head_size = int(hidden_size / num_attention_heads) + is_deepseek = "deepseek" in model_config.architectures[0].lower() + if not is_deepseek: + num_heads = int(model_config.num_key_value_heads / self.tp_size) + hidden_size = model_config.hidden_size + num_attention_heads = model_config.num_attention_heads + head_size = int(hidden_size / num_attention_heads) + else: + num_heads = int(model_config.num_key_value_heads / self.tp_size) + hidden_size = model_config.hidden_size + num_attention_heads = model_config.num_attention_heads + head_size = int(4.5 * hidden_size / num_attention_heads) # query_lens contains new KV caches that are added to vLLM. # so we will send them to decode instance @@ -175,27 +203,40 @@ class SimpleConnector(KVConnectorBase): start_pos = sum(seq_lens[:idx]) end_pos = start_pos + slen current_tokens = input_tokens_tensor[start_pos:end_pos] + current_request_id = request_ids[idx] + _, decode_kv_rank = self.parse_request_id(current_request_id) + starting_kv_group_rank = self._get_kv_group_rank(decode_kv_rank, 0, self.config) - keys, values = [], [] + for target_rank in range(self.config.tensor_parallel_multiplier): - for layer_id in range(start_layer, end_layer): - kv_cache = kv_caches[layer_id - start_layer] + keys, values = [], [] - key_cache = kv_cache[0].reshape(-1, num_heads, head_size) - value_cache = kv_cache[1].reshape(-1, num_heads, head_size) + for layer_id in range(start_layer, end_layer): + kv_cache = kv_caches[layer_id - start_layer] - current_slot_mapping = slot_mapping_flat[start_pos:end_pos] + current_slot_mapping = slot_mapping_flat[start_pos:end_pos] - keys.append(key_cache[current_slot_mapping].unsqueeze(0)) - values.append(value_cache[current_slot_mapping].unsqueeze(0)) + num_heads_per_rank = num_heads // self.config.tensor_parallel_multiplier + head_start = target_rank * num_heads_per_rank + head_end = head_start + num_heads_per_rank - keys = torch.cat(keys, dim=0) - values = torch.cat(values, dim=0) + if not is_deepseek: + key_cache = kv_cache[0].reshape(-1, num_heads, head_size) + value_cache = kv_cache[1].reshape(-1, num_heads, head_size) + keys.append(key_cache[current_slot_mapping, head_start:head_end].unsqueeze(0)) + values.append(value_cache[current_slot_mapping, head_start:head_end].unsqueeze(0)) + else: + key_cache = kv_cache + keys.append(key_cache[current_slot_mapping].unsqueeze(0)) + values.append(torch.empty(0)) - self.insert(current_tokens, - torch.ones_like(current_tokens, - dtype=bool), keys, values, - hidden_or_intermediate_states[start_pos:end_pos]) + keys = torch.cat(keys, dim=0) + values = torch.cat(values, dim=0) + + self.insert(starting_kv_group_rank, target_rank, current_tokens, + torch.ones_like(current_tokens, + dtype=bool), keys, values, + hidden_or_intermediate_states[start_pos:end_pos]) logger.debug("[rank%d]: KV send DONE.", 
torch.distributed.get_rank()) @@ -215,6 +256,7 @@ class SimpleConnector(KVConnectorBase): input_tokens_tensor = model_input.input_tokens seq_lens = model_input.attn_metadata.seq_lens slot_mapping = model_input.attn_metadata.slot_mapping.flatten() + request_ids = list(model_input.request_ids_to_seq_ids.keys()) hidden_or_intermediate_states_for_one_req = [] @@ -222,6 +264,9 @@ class SimpleConnector(KVConnectorBase): num_computed_tokens_list = [] start_pos_list = [] + model_config = model_executable.model.config + is_deepseek = "deepseek" in model_config.architectures[0].lower() + # enumerate different requests # FIXME(Kuntai): This impl assumes that all requests are prefill. for idx, slen in enumerate(seq_lens): @@ -229,13 +274,15 @@ class SimpleConnector(KVConnectorBase): start_pos = sum(seq_lens[:idx]) end_pos = start_pos + slen current_tokens = input_tokens_tensor[start_pos:end_pos] + current_request_id = request_ids[idx] + prefill_rank, _ = self.parse_request_id(current_request_id) num_tokens = slen # collecting data for rebuilding the input input_tokens_list.append(current_tokens) start_pos_list.append(start_pos) - ret = self.select(current_tokens, + ret = self.select(prefill_rank, current_tokens, torch.ones_like(current_tokens, dtype=bool)) if ret[0] is None: # didn't find any match. @@ -267,19 +314,25 @@ class SimpleConnector(KVConnectorBase): kv_cache = kv_caches[i - model_executable.model.start_layer] layer = model_executable.model.layers[i] - key_cache, value_cache = kv_cache[0], kv_cache[1] - ops.reshape_and_cache_flash( - keys[i - model_executable.model.start_layer].to( - key_cache.device), - values[i - model_executable.model.start_layer].to( - value_cache.device), - key_cache, - value_cache, - slot_mapping[start_pos:end_pos], - layer.self_attn.attn.kv_cache_dtype, - layer.self_attn.attn._k_scale, - layer.self_attn.attn._v_scale, - ) + if not is_deepseek: + key_cache, value_cache = kv_cache[0], kv_cache[1] + ops.reshape_and_cache_flash( + keys[i - model_executable.model.start_layer].to( + key_cache.device), + values[i - model_executable.model.start_layer].to( + value_cache.device), + key_cache, + value_cache, + slot_mapping[start_pos:end_pos], + layer.self_attn.attn.kv_cache_dtype, + layer.self_attn.attn._k_scale, + layer.self_attn.attn._v_scale, + ) + else: + key_cache = kv_cache + copy_from = keys[i - model_executable.model.start_layer].to( + key_cache.device) + kv_cache[slot_mapping[start_pos:end_pos]] = copy_from hidden_or_intermediate_states_for_one_req.append(hidden) @@ -312,3 +365,77 @@ class SimpleConnector(KVConnectorBase): # MooncakePipe reuses data_pipe for signal_pipe, so we only have to # close the data_pipe.
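Minimal check of the request-id convention consumed by parse_request_id (defined just below); the concrete id here is made up, but the suffix format mirrors the regex in the patch:

import re

request_id = "cmpl-42___prefill_kv_rank_0___decode_kv_rank_2"
pattern = r"___prefill_kv_rank_(\d+)___decode_kv_rank_(\d+)"
match = re.search(pattern, request_id)
assert match is not None
prefill_rank, decode_rank = int(match.group(1)), int(match.group(2))
assert (prefill_rank, decode_rank) == (0, 2)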
pass + + @staticmethod + def parse_request_id(request_id): + # Regular expression to match the ranks + pattern = r"___prefill_kv_rank_(\d+)___decode_kv_rank_(\d+)" + + # Use re.search to find the pattern in the request_id + match = re.search(pattern, request_id) + + if match: + # Extract the ranks + prefill_rank = int(match.group(1)) + decode_rank = int(match.group(2)) + + return prefill_rank, decode_rank + else: + return None, None + + + + def _get_kv_group_rank(self, kv_rank: int, rank: int, config: KVTransferConfig) -> int: + if kv_rank < config.kv_producers_parallel_size: + return kv_rank + + kv_consumer_rank = kv_rank - config.kv_producers_parallel_size + return config.kv_producers_parallel_size + kv_consumer_rank * config.tensor_parallel_multiplier + rank % config.tensor_parallel_multiplier + + def _broadcast_and_enhance_kv_config(self, rank: int, config: VllmConfig, world_group): + if rank == 0: + if self.config.kv_connector == "PyNcclConnector": + config_group = StatelessProcessGroup.create( + host=self.config.kv_ip, + port=self.config.kv_port, + rank=self.config.kv_rank, + world_size=self.config.kv_parallel_size, + ) + parallel_configs = config_group.all_gather_obj({ + "kv_role": self.config.kv_role, + "tensor_parallel_size": config.parallel_config.tensor_parallel_size, + "pipeline_parallel_size": config.parallel_config.pipeline_parallel_size, + }) + logger.debug("parallel_configs: %s", parallel_configs) + kv_config_enhanced = { + "kv_producers_tensor_parallel_size": None, + "kv_consumers_tensor_parallel_size": None, + "kv_producers_pipeline_parallel_size": None, + "kv_consumers_pipeline_parallel_size": None, + "kv_producers_parallel_size": 0, + } + for parallel_config in parallel_configs: + kv_role = parallel_config["kv_role"] + assert parallel_config["pipeline_parallel_size"] == 1, f"Only pipeline parallel size 1 is supported for kv transfer instances" + + if kv_role == "kv_producer": + kv_config_enhanced["kv_producers_parallel_size"] += 1 + if kv_config_enhanced[f"{kv_role}s_tensor_parallel_size"] is None: + kv_config_enhanced[f"{kv_role}s_tensor_parallel_size"] = parallel_config["tensor_parallel_size"] + kv_config_enhanced[f"{kv_role}s_pipeline_parallel_size"] = parallel_config["pipeline_parallel_size"] + else: + assert kv_config_enhanced[f"{kv_role}s_tensor_parallel_size"] == parallel_config["tensor_parallel_size"], f"All kv {kv_role}s should have the same tensor parallel size" + assert kv_config_enhanced[f"{kv_role}s_pipeline_parallel_size"] == parallel_config["pipeline_parallel_size"], f"All kv {kv_role}s should have the same pipeline parallel size" + world_group.broadcast_object(kv_config_enhanced) + + else: + raise NotImplementedError("MooncakeConnector is not supported in Dynamo patch") + else: + kv_config_enhanced = world_group.broadcast_object() + logger.info("kv_config_enhanced: %s", kv_config_enhanced) + + self.config.kv_producers_tensor_parallel_size = kv_config_enhanced["kv_producers_tensor_parallel_size"] + self.config.kv_consumers_tensor_parallel_size = kv_config_enhanced["kv_consumers_tensor_parallel_size"] + self.config.kv_producers_pipeline_parallel_size = kv_config_enhanced["kv_producers_pipeline_parallel_size"] + self.config.kv_consumers_pipeline_parallel_size = kv_config_enhanced["kv_consumers_pipeline_parallel_size"] + self.config.kv_producers_parallel_size = kv_config_enhanced["kv_producers_parallel_size"] diff --git a/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py b/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py index 
5e1b62352d14c..b45068775f431 100644 --- a/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py +++ b/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py @@ -12,7 +12,8 @@ import threading import time from collections import deque -from typing import Deque, List, Optional, Union +from concurrent.futures import ThreadPoolExecutor +from typing import Deque, List, Optional, Union, Dict import torch @@ -46,7 +47,7 @@ class SimpleBuffer(KVLookupBufferBase): self.buffer_lock = threading.Lock() self.signal_pipe = signal_pipe self.data_pipe = data_pipe - self.request_handling_thread: Optional[threading.Thread] = None + self.request_handling_thread: Optional[ThreadPoolExecutor] = None self.normal_signal = torch.tensor([0], device="cpu") self.end_signal = None @@ -57,10 +58,16 @@ class SimpleBuffer(KVLookupBufferBase): # tokens_roi_sender: tokens and roi of the producer (in the buffer) # tokens_roi_recver: tokens and roi of the consumer (query) - tokens_sender = tokens_roi_sender[0] - tokens_recver = tokens_roi_recver[0] - roi_sender = tokens_roi_sender[1] - roi_recver = tokens_roi_recver[1] + target_rank_sender = tokens_roi_sender[0] + target_rank_recver = tokens_roi_recver[0] + + if target_rank_sender.item() != target_rank_recver.item(): + return 0 + + tokens_sender = tokens_roi_sender[1] + tokens_recver = tokens_roi_recver[1] + roi_sender = tokens_roi_sender[2] + roi_recver = tokens_roi_recver[2] if tokens_recver is None: # consumer sends an empty request @@ -80,14 +87,14 @@ class SimpleBuffer(KVLookupBufferBase): return 0 - def _send_tensor_and_dec_size(self, - tensor: Optional[torch.Tensor]) -> None: + def _send_tensor_and_dec_size(self, tensor: Optional[torch.Tensor], + target_rank: int) -> None: assert tensor is not None, "Use self.data_pipe.send(None) instead" self.buffer_size -= tensor.element_size() * tensor.numel() if tensor.dtype == torch.bool: tensor = tensor.float() - self.data_pipe.send_tensor(tensor) + self.data_pipe.send_tensor(tensor, target_rank) def _get_element_size(self, data: Optional[Union[List, torch.Tensor]]): @@ -100,7 +107,7 @@ class SimpleBuffer(KVLookupBufferBase): raise AssertionError(f"Unknown data type {type(data)}") - def _add_to_buffer(self, input_tokens: torch.Tensor, roi: torch.Tensor, + def _add_to_buffer(self, target_rank: int, input_tokens: torch.Tensor, roi: torch.Tensor, key: torch.Tensor, value: torch.Tensor, hidden: torch.Tensor): @@ -115,7 +122,7 @@ class SimpleBuffer(KVLookupBufferBase): if isinstance(hidden, torch.Tensor): hidden = hidden.clone() - buffer_item = [input_tokens, roi, key, value, hidden] + buffer_item = [torch.tensor(target_rank), input_tokens, roi, key, value, hidden] with self.buffer_lock: for data in buffer_item: @@ -125,53 +132,54 @@ class SimpleBuffer(KVLookupBufferBase): def _is_end_signal(self, signal): return signal is None - def drop_select_handler(self): + def drop_select_handler(self, rank: int): try: - while True: - signal = self.signal_pipe.recv_tensor() - if self._is_end_signal(signal): - logger.info("Received end signal!") - break + signal = self.signal_pipe.recv_tensor(rank) + if self._is_end_signal(signal): + logger.info("Received end signal!") + return + target_kv_rank = self.data_pipe.recv_tensor(rank) + # assert target_rank.item() == rank, "Target rank does not match"\ + # "the rank of the drop-select handler" + input_tokens = self.data_pipe.recv_tensor(rank) + roi = self.data_pipe.recv_tensor(rank) + assert roi is not None, "Please provide the roi when sending "\ + "drop-select request" + roi = (roi 
> 0.5) + tokens_roi_recver = [target_kv_rank, input_tokens, roi] - input_tokens = self.data_pipe.recv_tensor() + matched_length = 0 - roi = self.data_pipe.recv_tensor() - assert roi is not None, "Please provide the roi when sending "\ - "drop-select request" - roi = (roi > 0.5) - tokens_roi_recver = [input_tokens, roi] + # perform input tokens and roi matching + # FIXME: this matching is O(n), ideally it should be O(1) + # but this buffer size won't (and shouldn't) be too large so + # the fix is not urgent. + with self.buffer_lock: - matched_length = 0 + for _ in range(len(self.buffer)): - # perform input tokens and roi matching - # FIXME: this matching is O(n), ideally it should be O(1) - # but this buffer size won't (and shouldn't) be too large so - # the fix is not urgent. - with self.buffer_lock: + temp_length = self._matches(self.buffer[0], + tokens_roi_recver) + if temp_length > 0: + matched_length = temp_length + break + # rotate the element we just accessed to the end + self.buffer.rotate(-1) - for _ in range(len(self.buffer)): + if matched_length > 0: + # need to clone the tensor + # in case the tensor is freed before sending finishes + matched_item = self.buffer.popleft() + target_rank = matched_item[0].item() + for tensor in matched_item[1:]: + self._send_tensor_and_dec_size(tensor, rank) - temp_length = self._matches(self.buffer[0], - tokens_roi_recver) - if temp_length > 0: - matched_length = temp_length - break - # rotate the element we just accessed to the end - self.buffer.rotate(-1) - - if matched_length > 0: - # need to clone the tensor - # in case the tensor is freed before sending finishes - matched_item = self.buffer.popleft() - for tensor in matched_item: - self._send_tensor_and_dec_size(tensor) - - else: - # no match, just send None - for _ in range(5): - self.data_pipe.send_tensor(None) + else: + # no match, just send None + for _ in range(5): + self.data_pipe.send_tensor(None, rank) except RuntimeError as e: if 'Connection closed by peer' not in str(e): @@ -180,10 +188,10 @@ class SimpleBuffer(KVLookupBufferBase): logger.debug("Closing drop_select_handler") def drop_select( - self, input_tokens: Optional[torch.Tensor], + self, rank: int, kv_rank: int, input_tokens: Optional[torch.Tensor], roi: Optional[torch.Tensor]) -> List[Optional[torch.Tensor]]: - assert self.request_handling_thread is None, \ + assert not self.request_handling_thread, \ "drop_select should be called by the KV cache consumer "\ "(e.g. 
the decode vLLM instance)" @@ -192,26 +200,28 @@ class SimpleBuffer(KVLookupBufferBase): if isinstance(roi, torch.Tensor): roi = roi.clone().float() - self.signal_pipe.send_tensor(self.normal_signal) - self.data_pipe.send_tensor(input_tokens) - self.data_pipe.send_tensor(roi) + self.signal_pipe.send_tensor(self.normal_signal, rank) - input_tokens = self.data_pipe.recv_tensor() - roi = self.data_pipe.recv_tensor() + self.data_pipe.send_tensor(torch.tensor(kv_rank), rank) + self.data_pipe.send_tensor(input_tokens, rank) + self.data_pipe.send_tensor(roi, rank) + + input_tokens = self.data_pipe.recv_tensor(rank) + roi = self.data_pipe.recv_tensor(rank) if roi is not None: # convert from float tensor to bool tensor # as PyNccl does not support sending bool tensor roi = (roi > 0.5) - key = self.data_pipe.recv_tensor() - value = self.data_pipe.recv_tensor() - hidden = self.data_pipe.recv_tensor() + key = self.data_pipe.recv_tensor(rank) + value = self.data_pipe.recv_tensor(rank) + hidden = self.data_pipe.recv_tensor(rank) return [input_tokens, roi, key, value, hidden] def full_handler(self): time.sleep(0.001) - def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor, + def insert(self, kv_group_rank: int, target_rank: int, input_tokens: torch.Tensor, roi: torch.Tensor, key: torch.Tensor, value: torch.Tensor, hidden: torch.Tensor) -> None: @@ -222,20 +232,19 @@ class SimpleBuffer(KVLookupBufferBase): while self.buffer_size > self.buffer_size_threshold: self.full_handler() - self._add_to_buffer(input_tokens, roi, key, value, hidden) + self._add_to_buffer(target_rank, input_tokens, roi, key, value, hidden) # when calling the insert, the current process is a sender # need to launch the request handler and start listening to request. + target_rank_global = target_rank + kv_group_rank if self.request_handling_thread is None: - self.request_handling_thread = threading.Thread( - target=self.drop_select_handler) - self.request_handling_thread.start() + self.request_handling_thread = ThreadPoolExecutor(max_workers=1) + self.request_handling_thread.submit(self.drop_select_handler, target_rank_global) def close(self): - if hasattr(self, "request_handling_thread" - ) and self.request_handling_thread is not None: - self.request_handling_thread.join() + if hasattr(self, "request_handling_thread") and self.request_handling_thread: + self.request_handling_thread.shutdown() else: # TODO: have a explicit close signal and have a explicit way to diff --git a/vllm/distributed/kv_transfer/kv_pipe/base.py b/vllm/distributed/kv_transfer/kv_pipe/base.py index 40589fb3ef872..da2829cfcc565 100644 --- a/vllm/distributed/kv_transfer/kv_pipe/base.py +++ b/vllm/distributed/kv_transfer/kv_pipe/base.py @@ -23,7 +23,7 @@ class KVPipeBase(ABC): """ @abstractmethod - def send_tensor(self, tensor: Optional[torch.Tensor]) -> None: + def send_tensor(self, tensor: Optional[torch.Tensor], target_rank: int = 0) -> None: """Send a tensor, or None, via the pipe. Need to support sending None -- important for error handling. @@ -41,7 +41,7 @@ class KVPipeBase(ABC): raise NotImplementedError @abstractmethod - def recv_tensor(self) -> Optional[torch.Tensor]: + def recv_tensor(self, src_rank: int) -> Optional[torch.Tensor]: """Receive a tensor (can be None) from the pipeline. 
Returns: diff --git a/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py b/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py index 7aa53d07a9ef2..f5dd50b7ab86a 100644 --- a/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py +++ b/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py @@ -45,33 +45,33 @@ class PyNcclPipe(KVPipeBase): METADATA_DTYPE = torch.int64 def __init__(self, + kv_group_rank: int, local_rank: int, config: KVTransferConfig, device: Optional[str] = None, port_offset: int = 0): self.config = config self.local_rank = local_rank - self.kv_rank = self.config.kv_rank + self.kv_group_rank = kv_group_rank self.kv_parallel_size = self.config.kv_parallel_size + self.kv_world_size = self.config.kv_world_size if device is None: self.device = self._select_device(self.config.kv_buffer_device) else: self.device = self._select_device(device) # build distributed connection and send/recv implementation + logger.info("Creating process group for kv transfer with rank %d and world size %d, ip: %s, port: %d", self.kv_group_rank, self.kv_world_size, self.config.kv_ip, self.config.kv_port + port_offset) self.group = StatelessProcessGroup.create( host=self.config.kv_ip, port=self.config.kv_port + port_offset, - rank=self.kv_rank, - world_size=self.kv_parallel_size, + rank=self.kv_group_rank, + world_size=self.kv_world_size, ) # add a barrier to make sure the connection is initiated properly self.group.barrier() impl = self._get_device_send_recv_impl(self.group) self.device_send_func, self.device_recv_func = impl - # set target rank - self.target_rank_for_send = (self.kv_rank + 1) % self.kv_parallel_size - self.target_rank_for_recv = (self.kv_rank - 1) % self.kv_parallel_size # transportation-related variables self.transport_thread: Optional[ThreadPoolExecutor] = None @@ -145,16 +145,16 @@ class PyNcclPipe(KVPipeBase): dtype=metadata["dtype"], device=self.device) - def _send_metadata(self, metadata: Metadata): + def _send_metadata(self, metadata: Metadata, target_rank: int): """ Send the metadata dictionary to the target rank. Parameters: - metadata: A dictionary with keys "dtype" and "shape". """ - self.group.send_obj(metadata, self.target_rank_for_send) + self.group.send_obj(metadata, target_rank) - def _recv_metadata(self) -> Metadata: + def _recv_metadata(self, src_rank: int) -> Metadata: """ Receive the metadata dictionary from the target rank. @@ -162,9 +162,9 @@ class PyNcclPipe(KVPipeBase): - metadata: A dictionary with keys "dtype" and "shape" describing the tensor. """ - return self.group.recv_obj(self.target_rank_for_recv) + return self.group.recv_obj(src_rank) - def _send_impl(self, tensor: Optional[torch.Tensor]) -> None: + def _send_impl(self, tensor: Optional[torch.Tensor], target_rank: int) -> None: """ The actual implementation of sending the tensor and its metadata to the target rank. @@ -174,12 +174,12 @@ class PyNcclPipe(KVPipeBase): being sent. """ metadata = self._make_metadata(tensor) - self._send_metadata(metadata) + self._send_metadata(metadata, target_rank) if tensor is not None: self.device_send_func(tensor.to(self.device), - self.target_rank_for_send) + target_rank) - def _recv_impl(self) -> Optional[torch.Tensor]: + def _recv_impl(self, src_rank: int) -> Optional[torch.Tensor]: """ The actual implementation of receiving a tensor and its metadata from the target rank. @@ -187,21 +187,22 @@ class PyNcclPipe(KVPipeBase): Returns: - buffer: The received tensor, or None if no tensor is received. 
""" - metadata = self._recv_metadata() + metadata = self._recv_metadata(src_rank) if metadata["dtype"] is None: return None buffer = self._prepare_recv_buffer(metadata) - self.device_recv_func(buffer, self.target_rank_for_recv) + self.device_recv_func(buffer, src_rank) return buffer def send_tensor_wrapper(self, tensor: Optional[torch.Tensor], - tensor_size: int) -> None: + tensor_size: int, + target_rank: int) -> None: """ Wrapper for _send_impl to handle exceptions and update buffer size. """ try: - self._send_impl(tensor) + self._send_impl(tensor, target_rank) with self.buffer_size_lock: self.buffer_size -= tensor_size @@ -220,7 +221,7 @@ class PyNcclPipe(KVPipeBase): logger.debug("KV cache transfer pipe is full. Waiting...") time.sleep(0.05) - def send_tensor(self, tensor: Optional[torch.Tensor]) -> None: + def send_tensor(self, tensor: Optional[torch.Tensor], target_rank: int) -> None: """ Sends a tensor and its metadata to the destination rank in a non-blocking way. @@ -228,6 +229,7 @@ class PyNcclPipe(KVPipeBase): Parameters: - tensor: The tensor to send, or None if no tensor is being sent. """ + logger.debug("Rank %d sending tensor of shape %s dtype %s to rank %d", self.kv_group_rank, tensor.shape if tensor is not None else "None", tensor.dtype if tensor is not None else "None", target_rank) if self.transport_thread is None: self.transport_thread = ThreadPoolExecutor(max_workers=1) @@ -241,32 +243,39 @@ class PyNcclPipe(KVPipeBase): with self.buffer_size_lock: self.buffer_size += tensor_size - self.transport_thread.submit(self.send_tensor_wrapper, tensor, - tensor_size) + future = self.transport_thread.submit(self.send_tensor_wrapper, tensor, + tensor_size, + target_rank) + return future - def recv_tensor(self) -> Optional[torch.Tensor]: + def recv_tensor(self, src_rank: int) -> Optional[torch.Tensor]: """ Receives a tensor and its metadata from the source rank. Blocking call. Returns: - tensor: The received tensor, or None if no tensor is received. """ + + logger.debug("Rank %d receiving tensor from rank %d", self.kv_group_rank, src_rank) + if self.transport_thread is None: self.transport_thread = ThreadPoolExecutor(max_workers=1) - future = self.transport_thread.submit(self._recv_impl) + future = self.transport_thread.submit(self._recv_impl, src_rank) - try: - tensor = future.result() - except Exception as e: - logger.error("Encountering exception in KV receiving thread") - logger.error("%s", e) - logger.error("My device: %s", self.device) - import traceback - traceback.print_exc() - raise e + return future - return tensor + # try: + # tensor = future.result() + # except Exception as e: + # logger.error("Encountering exception in KV receiving thread") + # logger.error("%s", e) + # logger.error("My device: %s", self.device) + # import traceback + # traceback.print_exc() + # raise e + + # return tensor def close(self): """ diff --git a/vllm/distributed/kv_transfer/kv_transfer_agent.py b/vllm/distributed/kv_transfer/kv_transfer_agent.py index 1e80e0bd7de86..cd90206f89b36 100644 --- a/vllm/distributed/kv_transfer/kv_transfer_agent.py +++ b/vllm/distributed/kv_transfer/kv_transfer_agent.py @@ -35,6 +35,7 @@ class KVTransferAgent: rank: int, local_rank: int, config: "VllmConfig", + world_group, ): self.config = config @@ -47,7 +48,7 @@ class KVTransferAgent: "TransferAgent should only be used when kv_connector is set." 
self.connector = KVConnectorFactory.create_connector( - rank, local_rank, config) + rank, local_rank, config, world_group) def send_kv_caches_and_hidden_states( self, diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 321902d11fd73..b8937ef86946d 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -1085,7 +1085,8 @@ def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None: _KV_TRANSFER = kv_transfer.KVTransferAgent( rank=get_world_group().rank, local_rank=get_world_group().local_rank, - config=vllm_config) + config=vllm_config, + world_group=get_world_group()) def ensure_model_parallel_initialized( diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index d82d9ad9df323..03896aa64f67d 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -2,13 +2,17 @@ import copy import time +import pickle +import uuid from collections import Counter as collectionsCounter from collections import deque +from collections import defaultdict from contextlib import contextmanager from dataclasses import dataclass +from concurrent.futures import ThreadPoolExecutor from functools import partial from typing import (TYPE_CHECKING, Callable, ClassVar, Deque, Dict, Iterable, - List, Mapping, NamedTuple, Optional) + List, Mapping, NamedTuple, Optional, Tuple) from typing import Sequence as GenericSequence from typing import Set, Type, Union, cast, overload @@ -60,6 +64,9 @@ from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled, usage_message) from vllm.utils import Counter, Device, deprecate_kwargs, weak_bind from vllm.version import __version__ as VLLM_VERSION +from vllm.remote_prefill import RemotePrefillRequest, RemotePrefillParams, MemoryTransferRequest, MemoryOpType +from vllm.distributed.device_communicators.nixl import NixlMetadata + logger = init_logger(__name__) _LOCAL_LOGGING_INTERVAL_SEC = 5 @@ -90,7 +97,7 @@ class OutputData(NamedTuple): # outputs from multiple steps. is_first_step_output: Optional[bool] skip: List[int] - + remote_prefill_requests: Optional[List[RemotePrefillRequest]] class SchedulerContext: @@ -104,11 +111,14 @@ class SchedulerContext: self.multi_step_stream_outputs: bool = multi_step_stream_outputs + self.remote_prefill_requests: List[RemotePrefillRequest] = [] + def append_output(self, outputs: List[SamplerOutput], seq_group_metadata_list: List[SequenceGroupMetadata], scheduler_outputs: SchedulerOutputs, is_async: bool, is_last_step: bool, - is_first_step_output: Optional[bool]): + is_first_step_output: Optional[bool], + remote_prefill_requests: Optional[List[RemotePrefillRequest]] = None): self.output_queue.append( OutputData(outputs=outputs, seq_group_metadata_list=seq_group_metadata_list, @@ -116,7 +126,9 @@ class SchedulerContext: is_async=is_async, is_last_step=is_last_step, is_first_step_output=is_first_step_output, - skip=[])) + skip=[], + remote_prefill_requests=remote_prefill_requests)) + class LLMEngine: @@ -348,7 +360,7 @@ class LLMEngine: # GPU and CPU blocks, which are profiled in the distributed executor. 
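With the fixed ±1 ring removed from the pipe, every producer and consumer rank joins one stateless group of size `kv_world_size` and addresses peers by explicit rank per call. A toy bootstrap sketch, assuming `StatelessProcessGroup` keeps the `create`/`barrier`/`send_obj`/`recv_obj` API used in the hunks above; host and port values are illustrative only:

```python
from vllm.distributed.utils import StatelessProcessGroup

def join_kv_group(kv_group_rank: int, kv_world_size: int) -> StatelessProcessGroup:
    # Every participant (all producer ranks plus all consumer ranks) calls
    # this with the same host/port and its own group rank.
    group = StatelessProcessGroup.create(
        host="127.0.0.1",   # illustrative
        port=15500,         # illustrative
        rank=kv_group_rank,
        world_size=kv_world_size,
    )
    group.barrier()  # mirrors the pipe's post-create barrier
    return group
```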
self.scheduler = [ Scheduler( - self.scheduler_config, self.cache_config, self.lora_config, + self.model_config, self.scheduler_config, self.cache_config, self.lora_config, self.parallel_config.pipeline_parallel_size, self.async_callbacks[v_id] if self.model_config.use_async_output_proc else None) @@ -405,6 +417,40 @@ class LLMEngine: self.seq_id_to_seq_group: Dict[str, SequenceGroupBase] = {} + self.engine_id = str(uuid.uuid4()) + self._nixl_agents_names: Optional[List[str]] = None + if self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.kv_connector == "DynamoNixlConnector": + self._nixl_agents_names = self._initialize_nixl() + + self._request_notif_counter = defaultdict(lambda: -self.parallel_config.tensor_parallel_size) + self._request_done_counter = defaultdict(lambda: -self.parallel_config.tensor_parallel_size) + self._finished_prefills = set() + self._finished_transfers = set() + + @property + def is_nixl_initialized(self) -> bool: + return getattr(self, "_nixl_agents_names", None) is not None + + def get_nixl_metadata(self) -> NixlMetadata: + if not self.is_nixl_initialized: + raise RuntimeError("Nixl is not initialized") + agent_metadata = self.model_executor.collective_rpc("get_nixl_agent_metadata") + kv_caches_base_addr = self.model_executor.collective_rpc("get_nixl_kv_caches_base_addr") + return NixlMetadata(engine_id=self.engine_id, agent_metadata=agent_metadata, kv_caches_base_addr=kv_caches_base_addr, num_blocks=self.cache_config.num_gpu_blocks) + + def add_remote_nixl_metadata(self, nixl_metadata: NixlMetadata) -> List[str]: + if not self.is_nixl_initialized: + raise RuntimeError("Nixl is not initialized") + engine_id = nixl_metadata.engine_id + agents_metadata = nixl_metadata.agent_metadata + kv_caches_base_addr = nixl_metadata.kv_caches_base_addr + num_blocks = nixl_metadata.num_blocks + return self.model_executor.collective_rpc("add_remote_nixl_metadata", args=(engine_id, agents_metadata, kv_caches_base_addr, num_blocks)) + + def _initialize_nixl(self) -> List[bytes]: + agents_names = self.model_executor.collective_rpc("initialize_nixl", args=(self.engine_id,)) + return agents_names + def _initialize_kv_caches(self) -> None: """Initialize the KV cache in the worker(s). @@ -500,6 +546,8 @@ class LLMEngine: # Shutdown model executor when engine is garbage collected # Use getattr since __init__ can fail before the field is set if model_executor := getattr(self, "model_executor", None): + if self.is_nixl_initialized: + model_executor.collective_rpc("shutdown_nixl") model_executor.shutdown() def get_tokenizer_group( @@ -552,11 +600,14 @@ class LLMEngine: prompt_adapter_request: Optional[PromptAdapterRequest], trace_headers: Optional[Mapping[str, str]] = None, priority: int = 0, + remote_prefill_params: Optional[RemotePrefillParams] = None, ) -> Optional[SequenceGroup]: """Add a processed request to the engine's request pool. return the created sequence group. """ if isinstance(params, SamplingParams) and params.n > 1: + if remote_prefill_params is not None: + raise ValueError("Remote prefill params are not supported for multi-step sampling") ParallelSampleSequenceGroup.add_request( request_id, self, @@ -574,6 +625,8 @@ class LLMEngine: # Create the sequences. 
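The engine-level helpers above (`get_nixl_metadata` / `add_remote_nixl_metadata`) exist so two engines can exchange NIXL identities out of band before any blocks move. A hedged sketch of that handshake; `prefill_engine` and `decode_engine` stand in for `LLMEngine` instances created with `kv_connector="DynamoNixlConnector"`, and the transport carrying the metadata between processes is up to the deployment:

```python
def pair_engines(prefill_engine, decode_engine) -> None:
    # Each side snapshots its NIXL identity: engine id, per-worker agent
    # metadata, KV cache base addresses and the GPU block count.
    prefill_md = prefill_engine.get_nixl_metadata()
    decode_md = decode_engine.get_nixl_metadata()

    # Registering the peer's metadata makes its agents addressable for the
    # later read_blocks/write_blocks calls on every TP worker.
    prefill_engine.add_remote_nixl_metadata(decode_md)
    decode_engine.add_remote_nixl_metadata(prefill_md)
```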
block_size = self.cache_config.block_size seq_id = next(self.seq_counter) + if remote_prefill_params is not None and remote_prefill_params.is_remote_decode: + next(self.seq_counter) # empty sequence for staging eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request) if is_encoder_decoder_inputs(processed_inputs): @@ -584,7 +637,7 @@ class LLMEngine: encoder_inputs = None seq = Sequence(seq_id, decoder_inputs, block_size, eos_token_id, - lora_request, prompt_adapter_request) + lora_request, prompt_adapter_request, remote_prefill_params) encoder_seq = (None if encoder_inputs is None else Sequence( seq_id, encoder_inputs, block_size, eos_token_id, lora_request, @@ -601,8 +654,12 @@ class LLMEngine: trace_headers=trace_headers, prompt_adapter_request=prompt_adapter_request, encoder_seq=encoder_seq, - priority=priority) + priority=priority, + remote_prefill_params=remote_prefill_params, + ) elif isinstance(params, PoolingParams): + if remote_prefill_params is not None: + raise ValueError("Remote prefill params are not supported for pooling") seq_group = self._create_sequence_group_with_pooling( request_id, seq, @@ -673,6 +730,7 @@ class LLMEngine: trace_headers: Optional[Mapping[str, str]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, priority: int = 0, + remote_prefill_params: Optional[RemotePrefillParams] = None, *, inputs: Optional[PromptType] = None, # DEPRECATED ) -> None: @@ -765,6 +823,7 @@ class LLMEngine: prompt_adapter_request=prompt_adapter_request, trace_headers=trace_headers, priority=priority, + remote_prefill_params=remote_prefill_params, ) def _validate_token_prompt(self, prompt: PromptType, @@ -799,6 +858,7 @@ class LLMEngine: prompt_adapter_request: Optional[PromptAdapterRequest] = None, encoder_seq: Optional[Sequence] = None, priority: int = 0, + remote_prefill_params: Optional[RemotePrefillParams] = None, ) -> SequenceGroup: """Creates a SequenceGroup with SamplingParams.""" max_logprobs = self.get_model_config().max_logprobs @@ -829,7 +889,9 @@ class LLMEngine: trace_headers=trace_headers, prompt_adapter_request=prompt_adapter_request, encoder_seq=encoder_seq, - priority=priority) + priority=priority, + remote_prefill_params=remote_prefill_params + ) return seq_group @@ -995,11 +1057,11 @@ class LLMEngine: # When we process only one request, no pop is required # (since later we will process all of the rest) (outputs, seq_group_metadata_list, scheduler_outputs, is_async, - is_last_step, is_first_step_output, skip) = ctx.output_queue[0] + is_last_step, is_first_step_output, skip, remote_prefill_requests) = ctx.output_queue[0] else: (outputs, seq_group_metadata_list, scheduler_outputs, is_async, is_last_step, is_first_step_output, - skip) = ctx.output_queue.popleft() + skip, remote_prefill_requests) = ctx.output_queue.popleft() # Sanity check assert len(seq_group_metadata_list) == len( @@ -1325,15 +1387,55 @@ class LLMEngine: # Clear outputs for each new scheduler iteration ctx.request_outputs.clear() + ctx.remote_prefill_requests.clear() # Skip the scheduler if there are any remaining steps in the seq groups. # This ensures that the scheduler is only called again when the current # batch has completed. 
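One subtle detail above: a remote-decode request consumes two sequence ids. The extra `next(self.seq_counter)` holds `seq_id + 1` free, which is what later lets the transfer path find the staging block table at `block_tables[seq_id + 1]`. Restated in isolation:

```python
from itertools import count

seq_counter = count()

def allocate_seq_ids(is_remote_decode: bool) -> int:
    seq_id = next(seq_counter)
    if is_remote_decode:
        # Burn one id so seq_id + 1 is reserved for the staging block table
        # that the memory-transfer path looks up later.
        next(seq_counter)
    return seq_id
```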
+ remote_prefill_seq_group_metadata_list: List[SequenceGroupMetadata] = [] + running_seq_group_metadata_list: List[SequenceGroupMetadata] = [] + remote_prefill_scheduled_seq_groups: List[ScheduledSequenceGroup] = [] + running_scheduled_seq_groups: List[ScheduledSequenceGroup] = [] + if not self._has_remaining_steps(seq_group_metadata_list): - # Schedule iteration + (seq_group_metadata_list, scheduler_outputs, allow_async_output_proc - ) = self.scheduler[virtual_engine].schedule() + ) = self.scheduler[virtual_engine].schedule(self._finished_prefills, self._finished_transfers) + + + # Separate remote prefill and running seq groups + for seq_group_metadata, scheduled_seq_group in zip(seq_group_metadata_list, scheduler_outputs.scheduled_seq_groups): + if seq_group_metadata.do_remote_prefill: + remote_prefill_seq_group_metadata_list.append(seq_group_metadata) + remote_prefill_scheduled_seq_groups.append(scheduled_seq_group) + else: + running_seq_group_metadata_list.append(seq_group_metadata) + running_scheduled_seq_groups.append(scheduled_seq_group) + + seq_group_metadata_list = running_seq_group_metadata_list + scheduler_outputs.scheduled_seq_groups = running_scheduled_seq_groups + + # Send remote prefill requests before model execution + for seq_group_metadata, scheduled_seq_group in zip(remote_prefill_seq_group_metadata_list, remote_prefill_scheduled_seq_groups): + assert len(scheduled_seq_group.seq_group.seqs) == 1 + assert self._nixl_agents_names + seq_id = scheduled_seq_group.seq_group.seqs[0].seq_id + block_table = seq_group_metadata.block_tables[seq_id] + if len(block_table) == len(seq_group_metadata.computed_block_nums): + logger.debug("No blocks to prefill") + self._finished_prefills.add(seq_group_metadata.request_id) + continue + remote_prefill_request = RemotePrefillRequest( + request_id=seq_group_metadata.request_id, + # prompt_token_ids=scheduled_seq_group.seq_group.seqs[0].inputs.prompt_token_ids[:-1], # last one will be decoded on decode for sampling anyway + prompt_token_ids=scheduled_seq_group.seq_group.seqs[0].inputs.prompt_token_ids, # TODO ptarasiewicz do not send the last token when NIXL fixes send notif (needed for writing 0 blocks) + sampling_params=scheduled_seq_group.seq_group.sampling_params, + block_ids=block_table, + engine_id=self.engine_id, + computed_block_ids=seq_group_metadata.computed_block_nums, + ) + scheduled_seq_group.seq_group.remote_prefill_params.remote_prefill_request_callback(remote_prefill_request) ctx.seq_group_metadata_list = seq_group_metadata_list ctx.scheduler_outputs = scheduler_outputs @@ -1383,9 +1485,46 @@ class LLMEngine: execute_model_req.async_callback = self.async_callbacks[ virtual_engine] - outputs = self.model_executor.execute_model( - execute_model_req=execute_model_req) + # After model execution, we need to transfer the memory from the prefill to the decode + memory_transfer_reqs = [] + for scheduled_seq_group, seq_group_metadata in zip(scheduler_outputs.scheduled_seq_groups, seq_group_metadata_list): + remote_prefill_params = scheduled_seq_group.seq_group.remote_prefill_params + if remote_prefill_params is not None and remote_prefill_params.is_remote_decode: + assert len(scheduled_seq_group.seq_group.seqs) == 1 + req_id = scheduled_seq_group.seq_group.request_id + seq_id = scheduled_seq_group.seq_group.seqs[0].seq_id + block_table = seq_group_metadata.block_tables[seq_id] + staging_block_ids = seq_group_metadata.block_tables[seq_id + 1] + num_computed_blocks = len(seq_group_metadata.computed_block_nums) + computed_decode_block_ids 
= remote_prefill_params.decode_block_ids[:num_computed_blocks] + + if computed_decode_block_ids: + kv_recv_req = MemoryTransferRequest( + request_id=req_id, + local_block_ids=block_table[:num_computed_blocks], + staging_block_ids=staging_block_ids[:num_computed_blocks], + remote_block_ids=computed_decode_block_ids, + remote_engine_id=remote_prefill_params.decode_engine_id, + notify_msg=req_id, + op_type=MemoryOpType.READ + ) + memory_transfer_reqs.append(kv_recv_req) + + kv_send_req = MemoryTransferRequest( + request_id=req_id, + local_block_ids=block_table[num_computed_blocks:], + staging_block_ids=staging_block_ids[num_computed_blocks:], + remote_block_ids=remote_prefill_params.decode_block_ids[num_computed_blocks:], + remote_engine_id=remote_prefill_params.decode_engine_id, + notify_msg=req_id, + op_type=MemoryOpType.WRITE + ) + memory_transfer_reqs.append(kv_send_req) + execute_model_req.memory_transfer_requests = memory_transfer_reqs + + outputs, request_notif_counter, request_done_counter = self.model_executor.execute_model( + execute_model_req=execute_model_req) # We need to do this here so that last step's sampled_token_ids can # be passed to the next iteration for PP. if self.scheduler_config.is_multi_step: @@ -1396,7 +1535,26 @@ class LLMEngine: if len(ctx.output_queue) > 0: self._process_model_outputs(ctx=ctx) # No outputs in this case - outputs = [] + execute_model_req = ExecuteModelRequest( + seq_group_metadata_list=[], + blocks_to_swap_in=[], + blocks_to_swap_out=[], + blocks_to_copy=[]) + + outputs, request_notif_counter, request_done_counter = self.model_executor.execute_model( + execute_model_req=execute_model_req) + + for req_id, notif_count in request_notif_counter.items(): + self._request_notif_counter[req_id] += notif_count + if self._request_notif_counter[req_id] > -1: + self._finished_prefills.add(req_id) + del self._request_notif_counter[req_id] + + for req_id, done_count in request_done_counter.items(): + self._request_done_counter[req_id] += done_count + if self._request_done_counter[req_id] > -1: + self._finished_transfers.add(req_id) + del self._request_done_counter[req_id] # Finish the current step for all the sequence groups. if self.scheduler_config.is_multi_step: @@ -1456,7 +1614,7 @@ class LLMEngine: # queued control plane messages, such as add/remove lora adapters. 
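The READ/WRITE planning above splits every remote-decode sequence at the prefix-cache boundary: blocks the decode engine already computed are read back, and the freshly prefilled suffix is written out. Restated compactly with the same `MemoryTransferRequest` fields as in the hunk, where `n` is `len(computed_block_nums)`:

```python
from vllm.remote_prefill import MemoryOpType, MemoryTransferRequest

def plan_transfers(req_id, block_table, staging, decode_blocks,
                   decode_engine_id, n):
    plans = []
    if decode_blocks[:n]:  # prefix already present on the decode side
        plans.append(MemoryTransferRequest(
            request_id=req_id,
            local_block_ids=block_table[:n],
            staging_block_ids=staging[:n],
            remote_block_ids=decode_blocks[:n],
            remote_engine_id=decode_engine_id,
            notify_msg=req_id,
            op_type=MemoryOpType.READ,
        ))
    plans.append(MemoryTransferRequest(  # newly prefilled suffix
        request_id=req_id,
        local_block_ids=block_table[n:],
        staging_block_ids=staging[n:],
        remote_block_ids=decode_blocks[n:],
        remote_engine_id=decode_engine_id,
        notify_msg=req_id,
        op_type=MemoryOpType.WRITE,
    ))
    return plans
```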
logger.debug("Stopping remote worker execution loop.") self.model_executor.stop_remote_worker_execution_loop() - + return ctx.request_outputs def _has_remaining_steps( diff --git a/vllm/engine/multiprocessing/__init__.py b/vllm/engine/multiprocessing/__init__.py index 3cf1850ee65ad..ae00657987774 100644 --- a/vllm/engine/multiprocessing/__init__.py +++ b/vllm/engine/multiprocessing/__init__.py @@ -14,13 +14,17 @@ from vllm.outputs import RequestOutput from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams from vllm.utils import deprecate_kwargs - +from vllm.remote_prefill import RemotePrefillParams +from vllm.distributed.device_communicators.nixl import NixlMetadata VLLM_RPC_SUCCESS_STR = "SUCCESS" IPC_INPUT_EXT = "_input_socket" IPC_OUTPUT_EXT = "_output_socket" IPC_HEALTH_EXT = "_health_socket" IPC_DATA_EXT = "_data_socket" +IPC_REMOTE_PREFILL_REQUEST_EXT = "_remote_prefill_request_socket" +IPC_REMOTE_NIXL_METADATA_EXT = "_remote_nixl_metadata_socket" +IPC_METRICS_EXT = "_metrics_socket" class MQEngineDeadError(RuntimeError): @@ -36,6 +40,7 @@ class RPCProcessRequest: trace_headers: Optional[Mapping[str, str]] = None prompt_adapter_request: Optional[PromptAdapterRequest] = None priority: int = 0 + remote_prefill_params: Optional[RemotePrefillParams] = None @overload def __init__( @@ -78,6 +83,7 @@ class RPCProcessRequest: trace_headers: Optional[Mapping[str, str]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, priority: int = 0, + remote_prefill_params: Optional[RemotePrefillParams] = None, *, inputs: Optional[PromptType] = None, # DEPRECATED ) -> None: @@ -95,7 +101,7 @@ class RPCProcessRequest: self.trace_headers = trace_headers self.prompt_adapter_request = prompt_adapter_request self.priority = priority - + self.remote_prefill_params = remote_prefill_params @dataclass class RPCError: @@ -116,7 +122,7 @@ class RPCStartupRequest(Enum): @dataclass class RPCStartupResponse: tracing_enabled: bool - + nixl_metadata: Optional[bytes] = None class RPCUProfileRequest(Enum): START_PROFILE = 1 @@ -157,3 +163,13 @@ def ENGINE_DEAD_ERROR( return MQEngineDeadError( "Engine loop is not running. 
Inspect the stacktrace to " f"find the original error: {repr(error)}.") + +@dataclass +class KvMetrics: + request_active_slots: int + request_total_slots: int + kv_active_blocks: int + kv_total_blocks: int + num_requests_waiting: int + gpu_cache_usage_perc: float + gpu_prefix_cache_hit_rate: float diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index 85b5f31e3a4aa..050302924ea46 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -8,6 +8,7 @@ from typing import (Any, AsyncGenerator, Dict, Iterator, List, Mapping, Optional, Union, cast, overload) import cloudpickle +import msgspec import psutil import zmq import zmq.asyncio @@ -19,20 +20,23 @@ from vllm import PoolingParams from vllm.config import DecodingConfig, ModelConfig, VllmConfig from vllm.core.scheduler import SchedulerOutputs from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.engine.metrics import Stats # yapf conflicts with isort for this block # yapf: disable from vllm.engine.async_llm_engine import ( build_guided_decoding_logits_processor_async) from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT, IPC_HEALTH_EXT, IPC_INPUT_EXT, - IPC_OUTPUT_EXT, RPC_REQUEST_T, - VLLM_RPC_SUCCESS_STR, RPCAbortRequest, + IPC_OUTPUT_EXT, IPC_REMOTE_PREFILL_REQUEST_EXT, + RPC_REQUEST_T, + VLLM_RPC_SUCCESS_STR, IPC_REMOTE_NIXL_METADATA_EXT, RPCAbortRequest, + IPC_METRICS_EXT, RPCAdapterLoadedResponse, RPCError, RPCLoadAdapterRequest, RPCProcessRequest, RPCResetPrefixCacheRequest, RPCStartupRequest, RPCStartupResponse, - RPCUProfileRequest) + RPCUProfileRequest, KvMetrics) from vllm.engine.protocol import EngineClient # yapf: enable from vllm.envs import VLLM_RPC_TIMEOUT @@ -46,6 +50,8 @@ from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs from vllm.utils import deprecate_kwargs +from vllm.remote_prefill import RemotePrefillParams, RemotePrefillRequest, RemotePrefillRequestCallback +from vllm.distributed.device_communicators.nixl import NixlMetadata logger = init_logger(__name__) @@ -91,6 +97,7 @@ class MQLLMEngineClient(EngineClient): self._errored_with: Optional[BaseException] = None # Get the configs. + self.vllm_config = engine_config self.model_config = engine_config.model_config self.decoding_config = engine_config.decoding_config @@ -115,6 +122,10 @@ class MQLLMEngineClient(EngineClient): self.heartbeat_socket: Socket = self.context.socket(zmq.constants.PULL) self.heartbeat_socket.connect(f"{ipc_path}{IPC_HEALTH_EXT}") + # Metrics. + self.metrics_socket: Socket = self.context.socket(zmq.constants.PULL) + self.metrics_socket.connect(f"{ipc_path}{IPC_METRICS_EXT}") + # IPC path for the data socket. self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}" @@ -129,8 +140,27 @@ class MQLLMEngineClient(EngineClient): # Loop to check health of the LLMEngine periodically. # Started after the MQLLMEngine is ready. self.health_loop: Optional[asyncio.Task] = None + + # Loop to check metrics of the LLMEngine periodically. + # Started after the MQLLMEngine is ready. 
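The new metrics channel is a plain ZeroMQ PUSH/PULL pair layered next to the existing health socket: the engine binds the PUSH side at `{ipc_path}_metrics_socket`, the client connects the PULL side, and payloads are pickled `KvMetrics`. A stripped-down sketch of the wiring; the path is illustrative (real endpoints carry vLLM's `ipc://` prefix):

```python
import pickle
import zmq

ctx = zmq.Context()

def make_engine_side(ipc_path: str) -> zmq.Socket:
    sock = ctx.socket(zmq.PUSH)
    sock.bind(f"{ipc_path}_metrics_socket")
    return sock

def make_client_side(ipc_path: str) -> zmq.Socket:
    sock = ctx.socket(zmq.PULL)
    sock.connect(f"{ipc_path}_metrics_socket")
    return sock

def publish(sock: zmq.Socket, metrics) -> None:
    # Same framing as the patch: one pickled KvMetrics per multipart message.
    sock.send_multipart((pickle.dumps(metrics),), copy=False)
```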
+ self.metrics_loop: Optional[asyncio.Task] = None + self.metrics_publisher = None + self._engine_process = psutil.Process(engine_pid) + self.nixl_metadata: Optional[NixlMetadata] = None + self.remote_prefill_request_socket: Socket = self.context.socket(zmq.constants.PULL) + self.remote_nixl_metadata_socket: Socket = self.context.socket(zmq.constants.PUSH) + self.remote_prefill_requests_callback: Dict[str, RemotePrefillRequestCallback] = {} + if self.using_nixl_connector: + self.remote_prefill_request_socket.connect(f"{ipc_path}{IPC_REMOTE_PREFILL_REQUEST_EXT}") + self.remote_nixl_metadata_socket.connect(f"{ipc_path}{IPC_REMOTE_NIXL_METADATA_EXT}") + + + @property + def using_nixl_connector(self) -> bool: + return self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.kv_connector == "DynamoNixlConnector" + @staticmethod def is_unsupported_config(engine_args: AsyncEngineArgs): # Pipeline parallel not yet supported @@ -180,6 +210,61 @@ class MQLLMEngineClient(EngineClient): except Exception as e: self._set_errored(e) + async def run_remote_prefill_request_handler_loop(self): + try: + while True: + if await self.remote_prefill_request_socket.poll(timeout=VLLM_RPC_TIMEOUT): + frames = await self.remote_prefill_request_socket.recv(copy=False) + remote_prefill_request = msgspec.msgpack.decode(frames.buffer, type=RemotePrefillRequest) + await self.remote_prefill_requests_callback[remote_prefill_request.request_id](remote_prefill_request) + except asyncio.CancelledError: + logger.debug("Shutting down MQLLMEngineClient remote prefill request handler loop.") + + async def run_metrics_loop(self, timeout: int): + """Background loop that continually checks to ensure the engine process + is still alive. + """ + try: + while True: + # Check if the engine process is running: + if not self._engine_process.is_running() or ( + self._engine_process.status() == psutil.STATUS_ZOMBIE): + # NB: is_running() returns True for zombies + self._set_errored( + RuntimeError( + f"Engine process (pid {self._engine_process.pid}) " + "died.")) + break + + if await self.metrics_socket.poll(timeout=timeout): + # Metrics received- check the message + message: Frame = await self.metrics_socket.recv(copy=False) + metrics = pickle.loads(message.buffer) + if self.metrics_publisher is not None and isinstance( + metrics, KvMetrics + ): + self.metrics_publisher.publish(metrics.request_active_slots, + metrics.request_total_slots, + metrics.kv_active_blocks, + metrics.kv_total_blocks, + metrics.num_requests_waiting, + metrics.gpu_cache_usage_perc, + metrics.gpu_prefix_cache_hit_rate) + logger.debug("Metrics successful.") + + # TODO: Investigate sending whole stats object + + except asyncio.CancelledError: + logger.debug("Shutting down MQLLMEngineClient check metrics loop.") + + except psutil.NoSuchProcess: + self._set_errored( + RuntimeError( + f"Engine process (pid {self._engine_process.pid}) died.")) + + except Exception as e: + self._set_errored(e) + async def run_output_handler_loop(self): """Get RequestOutputs from Engine and stream to Request Queues""" @@ -278,12 +363,26 @@ class MQLLMEngineClient(EngineClient): # Wait until server is ready. response = await self._wait_for_server_rpc(socket) + if response.nixl_metadata is not None: + assert self.using_nixl_connector + self.nixl_metadata = msgspec.msgpack.decode(response.nixl_metadata, type=NixlMetadata) + self.tracing_flag = response.tracing_enabled # Start health_loop. 
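Remote prefill requests cross the IPC boundary as typed msgpack frames, decoded with `msgspec.msgpack.decode(..., type=...)` as in the handler loop above. A self-contained round-trip with a stand-in struct (the real `RemotePrefillRequest` lives in `vllm.remote_prefill`):

```python
import msgspec

# Stand-in for RemotePrefillRequest; only the round-trip mechanics are shown.
class ToyPrefillRequest(msgspec.Struct):
    request_id: str
    block_ids: list[int]
    engine_id: str

wire = msgspec.msgpack.encode(
    ToyPrefillRequest(request_id="req-0", block_ids=[3, 7], engine_id="e1"))

# Typed decode, exactly the pattern used by the client handler loop.
decoded = msgspec.msgpack.decode(wire, type=ToyPrefillRequest)
assert decoded.block_ids == [3, 7]
```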
if self.health_loop is None: self.health_loop = asyncio.create_task( self.run_heartbeat_loop(timeout=VLLM_RPC_TIMEOUT)) + + if self.using_nixl_connector: + self.remote_prefill_loop = asyncio.create_task( + self.run_remote_prefill_request_handler_loop()) + + # Start metrics_loop. + if self.metrics_loop is None: + self.metrics_loop = asyncio.create_task( + self.run_metrics_loop(timeout=VLLM_RPC_TIMEOUT)) + def close(self): """Destroy the ZeroMQ Context.""" @@ -293,6 +392,8 @@ class MQLLMEngineClient(EngineClient): # Cancel background tasks. if self.health_loop is not None: self.health_loop.cancel() + if self.metrics_loop is not None: + self.metrics_loop.cancel() if self.output_loop is not None: self.output_loop.cancel() @@ -415,6 +516,9 @@ class MQLLMEngineClient(EngineClient): """ if self._errored_with is not None: raise self._errored_with + + async def add_remote_nixl_metadata(self, nixl_metadata: NixlMetadata): + await self.remote_nixl_metadata_socket.send(msgspec.msgpack.encode(nixl_metadata), copy=False) @property def is_running(self) -> bool: @@ -473,6 +577,7 @@ class MQLLMEngineClient(EngineClient): trace_headers: Optional[Mapping[str, str]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, priority: int = 0, + remote_prefill_params: Optional[RemotePrefillParams] = None, *, inputs: Optional[PromptType] = None # DEPRECATED ) -> AsyncGenerator[RequestOutput, None]: @@ -502,7 +607,8 @@ class MQLLMEngineClient(EngineClient): return self._process_request(prompt, sampling_params, request_id, lora_request, trace_headers, - prompt_adapter_request, priority) + prompt_adapter_request, priority, + remote_prefill_params) @overload def encode( @@ -586,6 +692,7 @@ class MQLLMEngineClient(EngineClient): trace_headers: Optional[Mapping[str, str]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, priority: int = 0, + remote_prefill_params: Optional[RemotePrefillParams] = None, ) -> Union[AsyncGenerator[RequestOutput, None], AsyncGenerator[ PoolingRequestOutput, None]]: """Send an RPCGenerateRequest to the RPCServer and stream responses.""" @@ -630,6 +737,12 @@ class MQLLMEngineClient(EngineClient): else: lp_bytes = None + if remote_prefill_params is not None: + self.remote_prefill_requests_callback[request_id] = remote_prefill_params.remote_prefill_request_callback + remote_prefill_params.remote_prefill_request_callback = None + else: + remote_prefill_request_callback = None + request_bytes = pickle.dumps( RPCProcessRequest( prompt=prompt, @@ -639,11 +752,11 @@ class MQLLMEngineClient(EngineClient): trace_headers=trace_headers, prompt_adapter_request=prompt_adapter_request, priority=priority, + remote_prefill_params=remote_prefill_params, )) # 3) Send the RPCGenerateRequest to the MQLLMEngine. - parts = (request_bytes, - lp_bytes) if lp_bytes else (request_bytes, ) + parts = (request_bytes, lp_bytes) if lp_bytes else (request_bytes,) await self.input_socket.send_multipart(parts, copy=False) # 4) Stream the RequestOutputs from the output queue. 
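Because the per-request callback cannot be pickled into the engine process, `_process_request` parks it client-side and strips it from the params before serialization; the engine later routes the generated request back over the socket and the client dispatches by request id. The pattern isolated, with hypothetical names:

```python
import asyncio
from typing import Awaitable, Callable, Dict

RequestCallback = Callable[[object], Awaitable[None]]

class CallbackRouter:
    """Hedged sketch of the callback-parking pattern in _process_request."""

    def __init__(self) -> None:
        self._callbacks: Dict[str, RequestCallback] = {}

    def park(self, request_id: str, params) -> None:
        # Keep the coroutine on this side; only a callback-free params
        # object crosses the process boundary.
        self._callbacks[request_id] = params.remote_prefill_request_callback
        params.remote_prefill_request_callback = None

    async def dispatch(self, request) -> None:
        # Invoked when a RemotePrefillRequest frame arrives from the engine.
        await self._callbacks[request.request_id](request)
```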
Note @@ -705,3 +818,6 @@ class MQLLMEngineClient(EngineClient): # Raise on error, otherwise happily return None if isinstance(request_output, BaseException): raise request_output + + def set_metrics_publisher(self, metrics_publisher): + self.metrics_publisher = metrics_publisher diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py index a0dd79586588e..c82bc15b69153 100644 --- a/vllm/engine/multiprocessing/engine.py +++ b/vllm/engine/multiprocessing/engine.py @@ -3,35 +3,115 @@ import pickle import signal from contextlib import contextmanager -from typing import Iterator, List, Optional, Union +from typing import Iterator, List, Optional, Union, Dict import cloudpickle +import time import zmq - +import msgspec from vllm import AsyncEngineArgs, SamplingParams from vllm.engine.llm_engine import LLMEngine # yapf conflicts with isort for this block # yapf: disable from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT, IPC_HEALTH_EXT, IPC_INPUT_EXT, - IPC_OUTPUT_EXT, REQUEST_OUTPUTS_T, - VLLM_RPC_SUCCESS_STR, RPCAbortRequest, + REQUEST_OUTPUTS_T, + VLLM_RPC_SUCCESS_STR, IPC_REMOTE_PREFILL_REQUEST_EXT, + RPCAbortRequest, + IPC_OUTPUT_EXT, IPC_METRICS_EXT, RPCAdapterLoadedResponse, RPCError, RPCLoadAdapterRequest, RPCProcessRequest, RPCResetPrefixCacheRequest, RPCStartupRequest, RPCStartupResponse, - RPCUProfileRequest) + RPCUProfileRequest, IPC_REMOTE_NIXL_METADATA_EXT, + KvMetrics) # yapf: enable from vllm.logger import init_logger from vllm.outputs import RequestOutput from vllm.usage.usage_lib import UsageContext +from vllm.remote_prefill import RemotePrefillRequest +from vllm.distributed.device_communicators.nixl import NixlMetadata + +from vllm.engine.metrics_types import StatLoggerBase, Stats, SupportsMetricsInfo +from dataclasses import dataclass, field logger = init_logger(__name__) POLLING_TIMEOUT_MS = 10000 HEALTHY_RESPONSE = (pickle.dumps(VLLM_RPC_SUCCESS_STR), ) +class KvStatLogger(StatLoggerBase): + def __init__( + self, + max_num_seqs: int, + num_total_gpu_blocks: int, + metrics_socket + ): + # Must query initialized scheduler for max infos + self.request_total_slots = max_num_seqs + self.kv_total_blocks = num_total_gpu_blocks + self.metrics_socket = metrics_socket + + # KV metrics + self._send_kv_metrics(0, 0, 0, 0.0, 0.0) + + def log(self, stats: Stats) -> None: + self._send_kv_metrics( + stats.num_running_sys, + int(stats.gpu_cache_usage_sys * self.kv_total_blocks), + stats.num_waiting_sys, + stats.gpu_cache_usage_sys, + stats.gpu_prefix_cache_hit_rate + ) + + def info(self, type: str, obj: SupportsMetricsInfo) -> None: + pass + + def _send_kv_metrics( + self, + active_slots, + active_kv_blocks, + num_requests_waiting, + gpu_cache_usage_perc, + gpu_prefix_cache_hit_rate, + ): + if not self.metrics_socket.closed: + metrics_bytes = pickle.dumps( + KvMetrics( + active_slots, + self.request_total_slots, + active_kv_blocks, + self.kv_total_blocks, + num_requests_waiting, + gpu_cache_usage_perc, + gpu_prefix_cache_hit_rate, + ) + ) + self.metrics_socket.send_multipart((metrics_bytes, ), copy=False) + +# TODO: Send entire stats object to the client +# class StatLogger(StatLoggerBase): +# def __init__( +# self, +# metrics_socket +# ): +# self.metrics_socket = metrics_socket + +# def log(self, stats: Stats) -> None: +# self._send_metrics(stats) + +# def info(self, type: str, obj: SupportsMetricsInfo) -> None: +# pass + +# def _send_metrics(self, stats: Stats): +# if not self.metrics_socket.closed: +# metrics_bytes = 
pickle.dumps(stats) +# self.metrics_socket.send_multipart((metrics_bytes, ), copy=False) + + + + class MQLLMEngine: """A multiprocessing wrapper for :class:`LLMEngine`. @@ -94,12 +174,37 @@ class MQLLMEngine: self.heartbeat_socket = self.ctx.socket(zmq.constants.PUSH) self.heartbeat_socket.bind(f"{ipc_path}{IPC_HEALTH_EXT}") + # Send metrics back to client. + self.metrics_socket = self.ctx.socket(zmq.constants.PUSH) + self.metrics_socket.bind(f"{ipc_path}{IPC_METRICS_EXT}") + # IPC path for the data socket. self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}" # Error state. self._errored_with: Optional[BaseException] = None + self.remote_prefill_request_socket = self.ctx.socket(zmq.constants.PUSH) + self.remote_nixl_metadata_socket = self.ctx.socket(zmq.constants.PULL) + if self.engine.is_nixl_initialized: + self.remote_prefill_request_socket.bind(f"{ipc_path}{IPC_REMOTE_PREFILL_REQUEST_EXT}") + self.remote_nixl_metadata_socket.bind(f"{ipc_path}{IPC_REMOTE_NIXL_METADATA_EXT}") + + + # Attach logger for continuous metrics publishing + self.kv_stat_logger = KvStatLogger( + self.engine.scheduler_config.max_num_seqs, + self.engine.cache_config.num_gpu_blocks, + self.metrics_socket + ) + self.engine.add_logger("kv_metrics", self.kv_stat_logger) + + # TODO investigate sending whole stats object + # self.general_stat_logger = StatLogger( + # self.metrics_socket + # ) + # self.engine.add_logger("general_metrics", self.general_stat_logger) + @property def dead_error(self) -> BaseException: if self._errored_with is not None: @@ -171,8 +276,17 @@ class MQLLMEngine: # Handle the query from the Client. if request == RPCStartupRequest.IS_SERVER_READY: tracing_enabled = self.engine.is_tracing_enabled() - response = RPCStartupResponse( - tracing_enabled=tracing_enabled) + + # Send nixl metadata to the client + if self.engine.is_nixl_initialized: + nixl_metadata = self.engine.get_nixl_metadata() + encoded_nixl_metadata = msgspec.msgpack.encode(nixl_metadata) + response = RPCStartupResponse( + tracing_enabled=tracing_enabled, + nixl_metadata=encoded_nixl_metadata) + else: + response = RPCStartupResponse( + tracing_enabled=tracing_enabled) except Exception as e: response = e @@ -185,6 +299,7 @@ class MQLLMEngine: while True: if not self.engine.has_unfinished_requests(): + logger.debug("No unfinished requests") # Poll until there is work to do. 
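`KvStatLogger` rides the engine's existing stat-logger hook: anything registered through `engine.add_logger` receives periodic `Stats` snapshots, and the absolute block count is derived from the reported usage fraction. A minimal logger in the same style (it skips `super().__init__`, mirroring the class above; registration would be `engine.add_logger("kv_print", PrintingKvLogger(...))`):

```python
from vllm.engine.metrics_types import (StatLoggerBase, Stats,
                                       SupportsMetricsInfo)

class PrintingKvLogger(StatLoggerBase):
    def __init__(self, kv_total_blocks: int) -> None:
        self.kv_total_blocks = kv_total_blocks

    def log(self, stats: Stats) -> None:
        # The scheduler reports a usage fraction; reconstruct the absolute
        # block count from the configured total, as KvStatLogger does.
        active_blocks = int(stats.gpu_cache_usage_sys * self.kv_total_blocks)
        print(f"running={stats.num_running_sys} "
              f"waiting={stats.num_waiting_sys} "
              f"kv_blocks={active_blocks}/{self.kv_total_blocks}")

    def info(self, type: str, obj: SupportsMetricsInfo) -> None:
        pass
```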
while self.input_socket.poll(timeout=POLLING_TIMEOUT_MS) == 0: # When there's no work, check on engine health and send @@ -220,6 +335,13 @@ class MQLLMEngine: def handle_new_input(self): """Handle new input from the socket""" try: + if self.engine.is_nixl_initialized: + while self.remote_nixl_metadata_socket.poll(timeout=0) != 0: + frames = self.remote_nixl_metadata_socket.recv(copy=False) + nixl_metadata = msgspec.msgpack.decode(frames.buffer, type=NixlMetadata) + logger.debug("Adding remote nixl metadata for engine: %s", nixl_metadata.engine_id) + self.engine.add_remote_nixl_metadata(nixl_metadata) + while self.input_socket.poll(timeout=0) != 0: frames = self.input_socket.recv_multipart(copy=False) request = pickle.loads(frames[0].buffer) @@ -262,6 +384,11 @@ class MQLLMEngine: self._send_outputs(rpc_err) try: + if request.remote_prefill_params is not None and request.remote_prefill_params.is_remote_prefill: + def remote_prefill_request_callback(request: RemotePrefillRequest): + logger.debug("Sending remote prefill request: %s", request.request_id) + self.remote_prefill_request_socket.send(msgspec.msgpack.encode(request), copy=False) + request.remote_prefill_params.remote_prefill_request_callback = remote_prefill_request_callback self.engine.add_request( request_id=request_id, prompt=request.prompt, @@ -269,7 +396,9 @@ class MQLLMEngine: lora_request=request.lora_request, trace_headers=request.trace_headers, prompt_adapter_request=request.prompt_adapter_request, - priority=request.priority) + priority=request.priority, + remote_prefill_params=request.remote_prefill_params, + ) if self.log_requests: logger.info("Added request %s.", request.request_id) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 107220d548afc..c716f75f721b6 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -34,6 +34,7 @@ from vllm.sampling_params import BeamSearchParams, SamplingParams from vllm.sequence import Logprob from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer from vllm.transformers_utils.tokenizers import maybe_serialize_tool_calls +from vllm.remote_prefill import RemotePrefillParams logger = init_logger(__name__) @@ -112,6 +113,7 @@ class OpenAIServingChat(OpenAIServing): self, request: ChatCompletionRequest, raw_request: Optional[Request] = None, + remote_prefill_params: Optional[RemotePrefillParams] = None, ) -> Union[AsyncGenerator[str, None], ChatCompletionResponse, ErrorResponse]: """ @@ -243,6 +245,7 @@ class OpenAIServingChat(OpenAIServing): trace_headers=trace_headers, prompt_adapter_request=prompt_adapter_request, priority=request.priority, + remote_prefill_params=remote_prefill_params, ) generators.append(generator) diff --git a/vllm/envs.py b/vllm/envs.py index 745b068b7a458..0ae63d9b76f44 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -87,6 +87,10 @@ if TYPE_CHECKING: VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON: bool = False VLLM_RAY_PER_WORKER_GPUS: float = 1.0 VLLM_RAY_BUNDLE_INDICES: str = "" + VLLM_KV_CAPI_PATH: Optional[str] = None + VLLM_KV_NAMESPACE: Optional[str] = None + VLLM_KV_COMPONENT: Optional[str] = None + VLLM_WORKER_ID: Optional[int] = None def get_default_cache_root(): @@ -572,6 +576,21 @@ environment_variables: Dict[str, Callable[[], Any]] = { # models the alignment is already naturally aligned to 256 bytes. 
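Returning to the serving-layer hunk above: `create_chat_completion` now threads `remote_prefill_params` straight through to `generate()`. A hypothetical decode-side caller; the constructor fields are assumptions inferred from their uses elsewhere in this patch:

```python
from vllm.remote_prefill import RemotePrefillParams

async def submit_with_remote_prefill(openai_serving_chat, request, prefill_queue):
    async def on_prefill_request(prefill_req):
        # Hand the generated RemotePrefillRequest to whichever prefill
        # engine was chosen (queue is a stand-in for the real transport).
        await prefill_queue.put(prefill_req)

    # Field names assumed: the patch reads is_remote_prefill and
    # remote_prefill_request_callback off this object.
    params = RemotePrefillParams(
        is_remote_prefill=True,
        remote_prefill_request_callback=on_prefill_request,
    )
    return await openai_serving_chat.create_chat_completion(
        request, remote_prefill_params=params)
```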
"VLLM_CUDA_MEM_ALIGN_KV_CACHE": lambda: bool(int(os.getenv("VLLM_CUDA_MEM_ALIGN_KV_CACHE", "1"))), + + # Path to the C API Library + "VLLM_KV_CAPI_PATH": + lambda: os.environ.get("VLLM_KV_CAPI_PATH", None), + + # Identifiers to publish KV related information + "VLLM_KV_NAMESPACE": + lambda: os.environ.get("VLLM_KV_NAMESPACE", None), + "VLLM_KV_COMPONENT": + lambda: os.environ.get("VLLM_KV_COMPONENT", None), + + # Worker ID used for identifying workers in distributed settings + "VLLM_WORKER_ID": + lambda: int(os.getenv("VLLM_WORKER_ID", "0")) + if "VLLM_WORKER_ID" in os.environ else None, } # end-env-vars-definition diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 773f5abe71dae..3eefd266f0551 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -585,6 +585,8 @@ class DeepseekV2Model(nn.Module): cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config + self.config = config + self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size diff --git a/vllm/outputs.py b/vllm/outputs.py index 786380c37f6cb..56a7cf8947cb6 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -6,16 +6,16 @@ from typing import Dict, Generic, List, MutableSequence, Optional from typing import Sequence as GenericSequence from typing import Union +import msgspec import torch from typing_extensions import TypeVar, deprecated from vllm.lora.request import LoRARequest from vllm.multimodal.inputs import MultiModalPlaceholderDict -from vllm.sampling_params import RequestOutputKind +from vllm.sampling_params import RequestOutputKind, SamplingParams from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs, SequenceGroup, SequenceGroupBase, SequenceStatus) - @dataclass class CompletionOutput: """The output data of one completion output of a request. diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index 97f9e21295731..1bb97b006aa5a 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -83,7 +83,7 @@ class RequestOutputKind(Enum): DELTA = 1 # Do not return intermediate RequestOuputs FINAL_ONLY = 2 - + class SamplingParams( msgspec.Struct, diff --git a/vllm/sequence.py b/vllm/sequence.py index 534b9e60610a2..18675d2fcb56c 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -20,6 +20,7 @@ from vllm.multimodal import MultiModalDataDict, MultiModalPlaceholderDict from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import RequestOutputKind, SamplingParams +from vllm.remote_prefill import RemotePrefillParams, MemoryTransferRequest VLLM_TOKEN_ID_ARRAY_TYPE = "l" @@ -59,13 +60,14 @@ class SequenceStatus(enum.IntEnum): """Status of a sequence.""" WAITING = 0 RUNNING = 1 - SWAPPED = 2 - # Note: anything after SWAPPED (2) will be considered + REMOTE_PREFILLING = 2 + SWAPPED = 3 + # Note: anything after SWAPPED (3) will be considered # as a finished status. 
- FINISHED_STOPPED = 3 - FINISHED_LENGTH_CAPPED = 4 - FINISHED_ABORTED = 5 - FINISHED_IGNORED = 6 + FINISHED_STOPPED = 4 + FINISHED_LENGTH_CAPPED = 5 + FINISHED_ABORTED = 6 + FINISHED_IGNORED = 7 @staticmethod def is_finished(status: "SequenceStatus") -> bool: @@ -409,6 +411,7 @@ class Sequence: eos_token_id: Optional[int] = None, lora_request: Optional[LoRARequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, + remote_prefill_params: Optional[RemotePrefillParams] = None, ) -> None: self.seq_id = seq_id self.inputs = SingletonInputsAdapter(inputs) @@ -416,7 +419,7 @@ class Sequence: self.eos_token_id = eos_token_id self.lora_request = lora_request self.prompt_adapter_request = prompt_adapter_request - + self.remote_prefill_params = remote_prefill_params self.data = SequenceData.from_seqs(self.prompt_token_ids) self.output_logprobs: SampleLogprobs = [] self.output_text = "" @@ -639,6 +642,7 @@ class SequenceGroup: trace_headers: OpenTelemetry trace headers. prompt_adapter_request: Prompt Adapter request. priority: User-defined priority of the request. + remote_prefill_params: Remote prefill parameters. """ def __init__( @@ -654,6 +658,7 @@ class SequenceGroup: trace_headers: Optional[Mapping[str, str]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, priority: int = 0, + remote_prefill_params: Optional[RemotePrefillParams] = None, ) -> None: self.request_id = request_id self.seqs = seqs @@ -678,7 +683,7 @@ class SequenceGroup: self.encoder_seq = encoder_seq self.trace_headers = trace_headers self.priority = priority - + self.remote_prefill_params = remote_prefill_params self.cached_request_output = None @property @@ -927,6 +932,9 @@ class SequenceGroupMetadata( query tokens for prefill, we don't need sampling. token_chunk_size: The number of tokens to be processed (per sequence). None if chunking is not required. + do_remote_prefill: True if remote prefill is required. + do_remote_decode: True if remote decode is required. + decode_memory_desc: The memory descriptor for the decoder blocks. lora_request: LoRA request. computed_block_nums: The block numbers that are already computed, used in prefix caching. @@ -966,6 +974,9 @@ class SequenceGroupMetadata( cross_block_table: Optional[List[int]] = None prompt_adapter_request: Optional[PromptAdapterRequest] = None token_chunk_size: Optional[int] = None + do_remote_prefill: bool = False + do_remote_decode: bool = False + decode_memory_desc: Optional[bytes] = None ### Stateful fields that are lazily defined. ### # The number of speculative tokens adopted in this request. @@ -1310,6 +1321,8 @@ class ExecuteModelRequest( last_sampled_token_ids: Optional[torch.Tensor] = None # Async callback async_callback: Optional[Callable] = None + # The memory transfer requests. 
+ memory_transfer_requests: Optional[List[MemoryTransferRequest]] = None @property def is_first_multi_step(self) -> bool: diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 12baecde6e42c..a3f2c464d0fd1 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1824,6 +1824,9 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]): if self.vllm_config.kv_transfer_config is None: return False + + if self.vllm_config.kv_transfer_config.kv_connector == "DynamoNixlConnector": + return False prefill_meta = model_input.attn_metadata.prefill_metadata @@ -1849,6 +1852,9 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]): if self.vllm_config.kv_transfer_config is None: return False + + if self.vllm_config.kv_transfer_config.kv_connector == "DynamoNixlConnector": + return False prefill_meta = model_input.attn_metadata.prefill_metadata diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 582aa460eb4fa..876329d6f8814 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -2,7 +2,7 @@ """A GPU worker class.""" import gc import os -from typing import Dict, List, Optional, Set, Tuple, Type, Union +from typing import Dict, List, Optional, Set, Tuple, Type, Union, TYPE_CHECKING, Any import torch import torch.distributed @@ -31,6 +31,9 @@ from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner from vllm.worker.pooling_model_runner import PoolingModelRunner from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase, WorkerInput) +from vllm.distributed.device_communicators.nixl import DynamoNixlConnector +from vllm.remote_prefill import MemoryOpType + logger = init_logger(__name__) @@ -306,6 +309,46 @@ class Worker(LocalOrDistributedWorkerBase): self._init_cache_engine() self._warm_up_model() + def initialize_nixl(self, engine_id: str) -> List[bytes]: + + # TODO ptarasiewicz nixl can also support DRAM + assert self.device_config.device_type == "cuda", "Currently only CUDA is supported for Nixl connector" + + self.nixl_connector = DynamoNixlConnector(self.vllm_config, engine_id, self.local_rank) # TODO ptarasiewicz: rank or local_rank? + assert len(self.cache_engine) == 1, "Only one cache engine is supported for now" + self.nixl_connector.register_kv_caches(self.cache_engine[0].gpu_cache) + return self.nixl_connector.agent_name + + def get_nixl_agent_metadata(self) -> bytes: + assert self.nixl_connector is not None, "Nixl connector is not initialized" + return self.nixl_connector.get_agent_metadata() + + def add_remote_nixl_metadata(self, engine_id: str, agents_metadata: List[bytes], kv_caches_base_addr: List[List[Tuple[int, int]]], num_blocks: int) -> str: + assert self.nixl_connector is not None, "Nixl connector is not initialized" + agent_name = self.nixl_connector.add_remote_agent(engine_id, agents_metadata, len(agents_metadata), kv_caches_base_addr, num_blocks) # TODO ptarasiewicz: rank or local_rank? 
+ return agent_name + + def get_nixl_kv_caches_base_addr(self) -> List[bytes]: + assert self.nixl_connector is not None, "Nixl connector is not initialized" + return self.nixl_connector.kv_caches_base_addr[self.nixl_connector.engine_id] + + def _read_blocks(self, worker_input: WorkerInput) -> None: + for i, op_type in enumerate(worker_input.op_type): + if op_type == MemoryOpType.READ: + self.nixl_connector.read_blocks(worker_input.local_block_ids[i], worker_input.staging_block_ids[i], worker_input.remote_block_ids[i], worker_input.remote_engine_id[i]) + + def _write_blocks(self, worker_input: WorkerInput) -> None: + if not self.is_driver_worker: + torch.cuda.synchronize() # to make sure that the blocks are ready, on driver worker we transfer after sampling, so there's no need to synchronize + + for i, op_type in enumerate(worker_input.op_type): + if op_type == MemoryOpType.WRITE: + self.nixl_connector.write_blocks(worker_input.local_block_ids[i], worker_input.staging_block_ids[i], worker_input.remote_block_ids[i], worker_input.remote_engine_id[i], worker_input.notify_msg[i]) + + def shutdown_nixl(self) -> None: + assert self.nixl_connector is not None, "Nixl connector is not initialized" + self.nixl_connector.shutdown() + def _init_cache_engine(self): assert self.cache_config.num_gpu_blocks is not None self.cache_engine = [ @@ -367,6 +410,8 @@ class Worker(LocalOrDistributedWorkerBase): blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy, device=self.device, dtype=torch.int64).view(-1, 2) + + mem_transfer_reqs = execute_model_req.memory_transfer_requests or [] return WorkerInput( num_seq_groups=num_seq_groups, @@ -375,6 +420,12 @@ class Worker(LocalOrDistributedWorkerBase): blocks_to_copy=blocks_to_copy, virtual_engine=virtual_engine, num_steps=num_steps, + local_block_ids=[r.local_block_ids for r in mem_transfer_reqs], + staging_block_ids=[r.staging_block_ids for r in mem_transfer_reqs], + remote_block_ids=[r.remote_block_ids for r in mem_transfer_reqs], + remote_engine_id=[r.remote_engine_id for r in mem_transfer_reqs], + notify_msg=[r.notify_msg for r in mem_transfer_reqs], + op_type=[r.op_type for r in mem_transfer_reqs], ) @torch.inference_mode() diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 819b81fbfdbb2..2891854b17e58 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -9,6 +9,7 @@ from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union import cloudpickle import torch import torch.nn as nn +from collections import defaultdict from vllm.config import (ObservabilityConfig, VllmConfig, set_current_vllm_config) @@ -23,6 +24,9 @@ from vllm.utils import (enable_trace_function_call_for_thread, from vllm.worker.model_runner_base import (BroadcastableModelInput, ModelRunnerBase, ModelRunnerInputBase) +from vllm.distributed.device_communicators.nixl import DynamoNixlConnector +from vllm.remote_prefill import MemoryOpType + logger = init_logger(__name__) @@ -53,6 +57,8 @@ class WorkerBase(ABC): from vllm.platforms import current_platform self.current_platform = current_platform + self.nixl_connector: Optional[DynamoNixlConnector] = None + @abstractmethod def init_device(self) -> None: """Initialize device state, such as loading the model or other on-device @@ -216,6 +222,13 @@ class WorkerInput: virtual_engine: int = 0 num_steps: int = 1 + local_block_ids: Optional[List[List[int]]] = None + staging_block_ids: Optional[List[List[int]]] = None + remote_block_ids: Optional[List[List[int]]] = None + remote_engine_id: 
Optional[List[str]] = None + notify_msg: Optional[List[str]] = None + op_type: Optional[List[MemoryOpType]] = None + @classmethod def from_broadcasted_tensor_dict( cls: Type["WorkerInput"], @@ -232,6 +245,12 @@ class WorkerInput: blocks_to_copy=tensor_dict.pop("blocks_to_copy"), virtual_engine=tensor_dict["virtual_engine"], num_steps=tensor_dict.pop("num_steps"), + local_block_ids=tensor_dict.pop("local_block_ids"), + staging_block_ids=tensor_dict.pop("staging_block_ids"), + remote_block_ids=tensor_dict.pop("remote_block_ids"), + remote_engine_id=tensor_dict.pop("remote_engine_id"), + notify_msg=tensor_dict.pop("notify_msg"), + op_type=tensor_dict.pop("op_type"), ) def as_broadcastable_tensor_dict( @@ -246,6 +265,12 @@ class WorkerInput: "blocks_to_copy": self.blocks_to_copy, "virtual_engine": self.virtual_engine, "num_steps": self.num_steps, + "local_block_ids": self.local_block_ids, + "staging_block_ids": self.staging_block_ids, + "remote_block_ids": self.remote_block_ids, + "remote_engine_id": self.remote_engine_id, + "notify_msg": self.notify_msg, + "op_type": self.op_type, } return tensor_dict @@ -316,13 +341,16 @@ class LocalOrDistributedWorkerBase(WorkerBase): return None worker_input = WorkerInput.from_broadcasted_tensor_dict(broadcast_data) - model_input = ( - self.model_runner.make_model_input_from_broadcasted_tensor_dict( - broadcast_data)) + if worker_input.num_seq_groups > 0: + model_input = ( + self.model_runner.make_model_input_from_broadcasted_tensor_dict( + broadcast_data)) - kwargs = extract_previous_hidden_states(broadcast_data) + kwargs = extract_previous_hidden_states(broadcast_data) - return model_input, worker_input, kwargs + return model_input, worker_input, kwargs + else: + return None, worker_input, {} def _get_driver_input_and_broadcast( self, execute_model_req: ExecuteModelRequest @@ -396,49 +424,88 @@ class LocalOrDistributedWorkerBase(WorkerBase): self.execute_worker(worker_input) # If there is no input, we don't need to execute the model. 
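The six new `WorkerInput` fields ride the same broadcastable dict as the existing block-op fields, so non-driver TP ranks reconstruct the transfer plan verbatim. A local round-trip sketch of the serialization added above:

```python
from vllm.remote_prefill import MemoryOpType
from vllm.worker.worker_base import WorkerInput

# Driver side: one WRITE plan for one request.
driver_input = WorkerInput(
    num_seq_groups=1,
    local_block_ids=[[0, 1]],
    staging_block_ids=[[8, 9]],
    remote_block_ids=[[4, 5]],
    remote_engine_id=["decode-engine-0"],
    notify_msg=["req-0"],
    op_type=[MemoryOpType.WRITE],
)

# Serialize and rebuild, as a non-driver rank would after the broadcast.
payload = driver_input.as_broadcastable_tensor_dict()
follower_input = WorkerInput.from_broadcasted_tensor_dict(payload)
assert follower_input.op_type == [MemoryOpType.WRITE]
```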
- if worker_input.num_seq_groups == 0: - return [] + if worker_input.num_seq_groups > 0: - intermediate_tensors = None - orig_model_execute_time = 0.0 - if not get_pp_group().is_first_rank: - intermediate_tensors = IntermediateTensors( - get_pp_group().recv_tensor_dict( - all_gather_group=get_tp_group())) + self._read_blocks(worker_input) + + intermediate_tensors = None + orig_model_execute_time = 0.0 + if not get_pp_group().is_first_rank: + intermediate_tensors = IntermediateTensors( + get_pp_group().recv_tensor_dict( + all_gather_group=get_tp_group())) + if (self.observability_config is not None + and self.observability_config.collect_model_execute_time): + orig_model_execute_time = intermediate_tensors.tensors.get( + "model_execute_time", torch.tensor(0)).item() + + output = self.model_runner.execute_model( + model_input=model_input, + kv_caches=self.kv_cache[worker_input.virtual_engine] + if self.kv_cache is not None else None, + intermediate_tensors=intermediate_tensors, + num_steps=num_steps, + **kwargs, + ) + + model_execute_time = time.perf_counter() - start_time + if not get_pp_group().is_last_rank: + # output is IntermediateTensors + assert isinstance(output, IntermediateTensors) + if (self.observability_config is not None + and self.observability_config.collect_model_execute_time): + output.tensors["model_execute_time"] = torch.tensor( + model_execute_time + orig_model_execute_time) + get_pp_group().send_tensor_dict(output.tensors, + all_gather_group=get_tp_group()) + return [None] if (self.observability_config is not None - and self.observability_config.collect_model_execute_time): - orig_model_execute_time = intermediate_tensors.tensors.get( - "model_execute_time", torch.tensor(0)).item() + and self.observability_config.collect_model_execute_time + and output is not None): + for o in output: + o.model_execute_time = (orig_model_execute_time + + model_execute_time) - output = self.model_runner.execute_model( - model_input=model_input, - kv_caches=self.kv_cache[worker_input.virtual_engine] - if self.kv_cache is not None else None, - intermediate_tensors=intermediate_tensors, - num_steps=num_steps, - **kwargs, - ) + self._write_blocks(worker_input) - model_execute_time = time.perf_counter() - start_time - if not get_pp_group().is_last_rank: - # output is IntermediateTensors - assert isinstance(output, IntermediateTensors) - if (self.observability_config is not None - and self.observability_config.collect_model_execute_time): - output.tensors["model_execute_time"] = torch.tensor( - model_execute_time + orig_model_execute_time) - get_pp_group().send_tensor_dict(output.tensors, - all_gather_group=get_tp_group()) - return [None] - if (self.observability_config is not None - and self.observability_config.collect_model_execute_time - and output is not None): - for o in output: - o.model_execute_time = (orig_model_execute_time + - model_execute_time) + else: + output = [] + # collect kv transfer notifications from non driver workers + + if self.nixl_connector is not None: + new_notifs = self.nixl_connector.get_new_notifs() + rank = get_tp_group().rank + all_new_notifs = [new_notifs] + if rank > 0: + get_tp_group().send_object(new_notifs, dst=0) + else: + for i in range(1, get_tp_group().world_size): + all_new_notifs.append(get_tp_group().recv_object(src=i)) + + request_notif_counter = defaultdict(int) + for notifs in all_new_notifs: + for req_ids in notifs.values(): + for req_id in req_ids: + request_notif_counter[req_id] += 1 + + if request_notif_counter: + logger.debug("Request notif 
counter: %s", request_notif_counter) + + request_done_counter = defaultdict(int) + for req_id in self.nixl_connector.get_done_tranfers(): + request_done_counter[req_id] += 1 + else: + request_notif_counter = {} + request_done_counter = {} # output is List[SamplerOutput] - return output + return output, request_notif_counter, request_done_counter + + def _read_blocks(self, worker_input: WorkerInput) -> None: + pass + + def _write_blocks(self, worker_input: WorkerInput) -> None: + pass def _execute_model_spmd( self,