diff --git a/tests/worker/test_swap.py b/tests/worker/test_swap.py
new file mode 100644
index 0000000000000..35630a06a900f
--- /dev/null
+++ b/tests/worker/test_swap.py
@@ -0,0 +1,77 @@
+import torch
+
+from vllm.engine.arg_utils import EngineArgs
+from vllm.worker.worker import Worker
+from vllm.utils import get_distributed_init_method, get_ip, get_open_port
+
+
+def test_swap() -> None:
+    # Configure the engine.
+    engine_args = EngineArgs(model="facebook/opt-125m",
+                             dtype="half",
+                             load_format="dummy")
+    (model_config, cache_config, parallel_config, scheduler_config,
+     device_config, _) = engine_args.create_engine_configs()
+    cache_config.num_gpu_blocks = 100
+    cache_config.num_cpu_blocks = 100
+
+    # Create the worker.
+    distributed_init_method = get_distributed_init_method(
+        get_ip(), get_open_port())
+    worker = Worker(
+        model_config=model_config,
+        parallel_config=parallel_config,
+        scheduler_config=scheduler_config,
+        device_config=device_config,
+        local_rank=0,
+        rank=0,
+        distributed_init_method=distributed_init_method,
+        is_driver_worker=True,
+    )
+
+    # Initialize the worker.
+    worker.init_model()
+    worker.load_model()
+    worker.init_cache_engine(cache_config)
+    worker.warm_up_model()
+
+    # Randomly initialize the cache.
+    gpu_cache = worker.cache_engine.gpu_cache
+    cpu_cache = worker.cache_engine.cpu_cache
+    num_layers = len(gpu_cache)
+    for i in range(num_layers):
+        gpu_key_cache, gpu_value_cache = gpu_cache[i]
+        gpu_key_cache.random_()
+        gpu_value_cache.random_()
+        cpu_key_cache, cpu_value_cache = cpu_cache[i]
+        cpu_key_cache.random_()
+        cpu_value_cache.random_()
+
+    allclose = lambda a, b: torch.allclose(
+        a.cuda(), b.cuda(), rtol=0.0, atol=0.0)
+
+    # Test swap out.
+    blocks_to_swap_out = {3: 72, 56: 35, 84: 34}
+    worker.execute_model(seq_group_metadata_list=[],
+                         blocks_to_swap_in={},
+                         blocks_to_swap_out=blocks_to_swap_out,
+                         blocks_to_copy={})
+    for i in range(num_layers):
+        gpu_key_cache, gpu_value_cache = gpu_cache[i]
+        cpu_key_cache, cpu_value_cache = cpu_cache[i]
+        for src, dst in blocks_to_swap_out.items():
+            assert allclose(gpu_key_cache[src], cpu_key_cache[dst])
+            assert allclose(gpu_value_cache[src], cpu_value_cache[dst])
+
+    # Test swap in.
+    blocks_to_swap_in = {19: 45, 67: 23, 12: 78, 40: 99, 1: 71}
+    worker.execute_model(seq_group_metadata_list=[],
+                         blocks_to_swap_in=blocks_to_swap_in,
+                         blocks_to_swap_out={},
+                         blocks_to_copy={})
+    for i in range(num_layers):
+        gpu_key_cache, gpu_value_cache = gpu_cache[i]
+        cpu_key_cache, cpu_value_cache = cpu_cache[i]
+        for src, dst in blocks_to_swap_in.items():
+            assert allclose(gpu_key_cache[dst], cpu_key_cache[src])
+            assert allclose(gpu_value_cache[dst], cpu_value_cache[src])
diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py
index 880299783935c..1782fe7e57177 100644
--- a/vllm/worker/cache_engine.py
+++ b/vllm/worker/cache_engine.py
@@ -38,7 +38,7 @@ class CacheEngine:
         self.num_gpu_blocks = cache_config.num_gpu_blocks
         self.num_cpu_blocks = cache_config.num_cpu_blocks
 
-        # Skip initializing CUDA stream and buffer for Neuron backend.
+        # Skip initializing KV cache for Neuron backend.
         if is_neuron():
             return
 
@@ -51,12 +51,6 @@ class CacheEngine:
         self.gpu_cache = self.allocate_gpu_cache()
         self.cpu_cache = self.allocate_cpu_cache()
 
-        # Initialize the stream for caching operations.
-        self.cache_stream = torch.cuda.Stream()
-        assert self.cache_stream != torch.cuda.current_stream()
-        # Initialize the events for stream synchronization.
-        self.events = [torch.cuda.Event() for _ in range(self.num_layers)]
-
     def get_key_block_shape(self) -> Tuple[int, int, int, int]:
         element_size = torch.tensor([], dtype=self.dtype).element_size()
         x = 16 // element_size
@@ -126,17 +120,13 @@ class CacheEngine:
     ) -> None:
         from vllm._C import cache_ops
 
-        with torch.cuda.stream(self.cache_stream):
-            for i in range(self.num_layers):
-                src_key_cache, src_value_cache = src[i]
-                dst_key_cache, dst_value_cache = dst[i]
-                # Copy the key blocks.
-                cache_ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst)
-                # Copy the value blocks.
-                cache_ops.swap_blocks(src_value_cache, dst_value_cache,
-                                      src_to_dst)
-                event = self.events[i]
-                event.record(stream=self.cache_stream)
+        for i in range(self.num_layers):
+            src_key_cache, src_value_cache = src[i]
+            dst_key_cache, dst_value_cache = dst[i]
+            # Copy the key blocks.
+            cache_ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst)
+            # Copy the value blocks.
+            cache_ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst)
 
     def swap_in(self, src_to_dst: Dict[int, int]) -> None:
         self._swap(self.cpu_cache, self.gpu_cache, src_to_dst)
diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py
index 0dcd4018afa5f..81beb5ce4d8d4 100644
--- a/vllm/worker/worker.py
+++ b/vllm/worker/worker.py
@@ -65,7 +65,6 @@ class Worker:
         # self.init_cache_engine().
         self.cache_config = None
         self.cache_engine = None
-        self.cache_events = None
         self.gpu_cache = None
 
     def init_model(self, cupy_port: Optional[int] = None) -> None:
@@ -148,7 +147,6 @@ class Worker:
         self.cache_config = cache_config
         self.cache_engine = CacheEngine(self.cache_config, self.model_config,
                                         self.parallel_config)
-        self.cache_events = self.cache_engine.events
         self.gpu_cache = self.cache_engine.gpu_cache
         self.model_runner.set_block_size(self.cache_engine.block_size)
 
@@ -166,24 +164,13 @@ class Worker:
         blocks_to_copy: Dict[int, List[int]],
     ) -> None:
         # Issue cache operations.
-        issued_cache_op = False
+        # TODO(woosuk): Profile swapping overhead and optimize if needed.
         if blocks_to_swap_in:
             self.cache_engine.swap_in(blocks_to_swap_in)
-            issued_cache_op = True
         if blocks_to_swap_out:
             self.cache_engine.swap_out(blocks_to_swap_out)
-            issued_cache_op = True
         if blocks_to_copy:
             self.cache_engine.copy(blocks_to_copy)
-            issued_cache_op = True
-
-        cache_events = self.cache_events if issued_cache_op else None
-
-        # Wait for cache operations to finish.
-        # TODO(woosuk): Profile swapping overhead and optimize if needed.
-        if cache_events is not None:
-            for event in cache_events:
-                event.wait()
 
     @torch.inference_mode()
     def execute_model(
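An illustrative sketch, not part of the patch: the dedicated `cache_stream` and per-layer `torch.cuda.Event` objects can be dropped because `cache_ops.swap_blocks` is now issued on the current CUDA stream, and CUDA executes work on a single stream in issue order, so any kernel launched afterwards on that stream already sees the swapped blocks. The snippet below demonstrates that per-stream ordering with plain PyTorch copies; the tensor names and shapes are invented for the example and do not reflect the vLLM cache layout.

```python
import torch

# Stand-ins for a CPU-resident and a GPU-resident cache block.
# Shapes and names are illustrative only.
cpu_block = torch.randn(16, 128, pin_memory=True)
gpu_block = torch.empty(16, 128, device="cuda")

# Issue the host-to-device copy on the current stream, analogous to what
# swap_in does after this change (no side stream, no events).
gpu_block.copy_(cpu_block, non_blocking=True)

# A kernel launched afterwards on the same stream is ordered after the copy,
# so it observes the transferred data without explicit synchronization.
checksum = gpu_block.sum()

# Host-side synchronization is only needed to read the result back here.
torch.cuda.synchronize()
print(checksum.item())
```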