[Misc] Remove cache stream and cache events (#3461)
parent 4ad521d8b5
commit 5ee14494e4
tests/worker/test_swap.py  (new file, 77 lines)
@@ -0,0 +1,77 @@
import torch

from vllm.engine.arg_utils import EngineArgs
from vllm.worker.worker import Worker
from vllm.utils import get_distributed_init_method, get_ip, get_open_port


def test_swap() -> None:
    # Configure the engine.
    engine_args = EngineArgs(model="facebook/opt-125m",
                             dtype="half",
                             load_format="dummy")
    (model_config, cache_config, parallel_config, scheduler_config,
     device_config, _) = engine_args.create_engine_configs()
    cache_config.num_gpu_blocks = 100
    cache_config.num_cpu_blocks = 100

    # Create the worker.
    distributed_init_method = get_distributed_init_method(
        get_ip(), get_open_port())
    worker = Worker(
        model_config=model_config,
        parallel_config=parallel_config,
        scheduler_config=scheduler_config,
        device_config=device_config,
        local_rank=0,
        rank=0,
        distributed_init_method=distributed_init_method,
        is_driver_worker=True,
    )

    # Initialize the worker.
    worker.init_model()
    worker.load_model()
    worker.init_cache_engine(cache_config)
    worker.warm_up_model()

    # Randomly initialize the cache.
    gpu_cache = worker.cache_engine.gpu_cache
    cpu_cache = worker.cache_engine.cpu_cache
    num_layers = len(gpu_cache)
    for i in range(num_layers):
        gpu_key_cache, gpu_value_cache = gpu_cache[i]
        gpu_key_cache.random_()
        gpu_value_cache.random_()
        cpu_key_cache, cpu_value_cache = cpu_cache[i]
        cpu_key_cache.random_()
        cpu_value_cache.random_()

    allclose = lambda a, b: torch.allclose(
        a.cuda(), b.cuda(), rtol=0.0, atol=0.0)

    # Test swap out.
    blocks_to_swap_out = {3: 72, 56: 35, 84: 34}
    worker.execute_model(seq_group_metadata_list=[],
                         blocks_to_swap_in={},
                         blocks_to_swap_out=blocks_to_swap_out,
                         blocks_to_copy={})
    for i in range(num_layers):
        gpu_key_cache, gpu_value_cache = gpu_cache[i]
        cpu_key_cache, cpu_value_cache = cpu_cache[i]
        for src, dst in blocks_to_swap_out.items():
            assert allclose(gpu_key_cache[src], cpu_key_cache[dst])
            assert allclose(gpu_value_cache[src], cpu_value_cache[dst])

    # Test swap in.
    blocks_to_swap_in = {19: 45, 67: 23, 12: 78, 40: 99, 1: 71}
    worker.execute_model(seq_group_metadata_list=[],
                         blocks_to_swap_in=blocks_to_swap_in,
                         blocks_to_swap_out={},
                         blocks_to_copy={})
    for i in range(num_layers):
        gpu_key_cache, gpu_value_cache = gpu_cache[i]
        cpu_key_cache, cpu_value_cache = cpu_cache[i]
        for src, dst in blocks_to_swap_in.items():
            assert allclose(gpu_key_cache[dst], cpu_key_cache[src])
            assert allclose(gpu_value_cache[dst], cpu_value_cache[src])
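Note: the test allocates real GPU and CPU cache blocks and swaps between them, so it needs a machine with a CUDA device. As a minimal sketch that is not part of this commit, a skip guard for CPU-only environments could look like the following; `requires_cuda` and `test_swap_placeholder` are hypothetical names used only for illustration.

import pytest
import torch

# Hypothetical marker (not in the diff): skip swap tests on machines without
# a CUDA device, since the GPU-side cache blocks cannot be allocated there.
requires_cuda = pytest.mark.skipif(not torch.cuda.is_available(),
                                   reason="test_swap needs a CUDA device")

@requires_cuda
def test_swap_placeholder() -> None:  # stand-in for the real test_swap above
    ...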
vllm/worker/cache_engine.py
@@ -38,7 +38,7 @@ class CacheEngine:
         self.num_gpu_blocks = cache_config.num_gpu_blocks
         self.num_cpu_blocks = cache_config.num_cpu_blocks
 
-        # Skip initializing CUDA stream and buffer for Neuron backend.
+        # Skip initializing KV cache for Neuron backend.
         if is_neuron():
             return
 
@@ -51,12 +51,6 @@ class CacheEngine:
         self.gpu_cache = self.allocate_gpu_cache()
         self.cpu_cache = self.allocate_cpu_cache()
 
-        # Initialize the stream for caching operations.
-        self.cache_stream = torch.cuda.Stream()
-        assert self.cache_stream != torch.cuda.current_stream()
-        # Initialize the events for stream synchronization.
-        self.events = [torch.cuda.Event() for _ in range(self.num_layers)]
-
     def get_key_block_shape(self) -> Tuple[int, int, int, int]:
         element_size = torch.tensor([], dtype=self.dtype).element_size()
         x = 16 // element_size
@@ -126,17 +120,13 @@ class CacheEngine:
     ) -> None:
         from vllm._C import cache_ops
 
-        with torch.cuda.stream(self.cache_stream):
-            for i in range(self.num_layers):
-                src_key_cache, src_value_cache = src[i]
-                dst_key_cache, dst_value_cache = dst[i]
-                # Copy the key blocks.
-                cache_ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst)
-                # Copy the value blocks.
-                cache_ops.swap_blocks(src_value_cache, dst_value_cache,
-                                      src_to_dst)
-                event = self.events[i]
-                event.record(stream=self.cache_stream)
+        for i in range(self.num_layers):
+            src_key_cache, src_value_cache = src[i]
+            dst_key_cache, dst_value_cache = dst[i]
+            # Copy the key blocks.
+            cache_ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst)
+            # Copy the value blocks.
+            cache_ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst)
 
     def swap_in(self, src_to_dst: Dict[int, int]) -> None:
         self._swap(self.cpu_cache, self.gpu_cache, src_to_dst)
vllm/worker/worker.py
@@ -65,7 +65,6 @@ class Worker:
         # self.init_cache_engine().
         self.cache_config = None
         self.cache_engine = None
-        self.cache_events = None
         self.gpu_cache = None
 
     def init_model(self, cupy_port: Optional[int] = None) -> None:
@@ -148,7 +147,6 @@ class Worker:
         self.cache_config = cache_config
         self.cache_engine = CacheEngine(self.cache_config, self.model_config,
                                         self.parallel_config)
-        self.cache_events = self.cache_engine.events
         self.gpu_cache = self.cache_engine.gpu_cache
         self.model_runner.set_block_size(self.cache_engine.block_size)
 
@@ -166,24 +164,13 @@ class Worker:
         blocks_to_copy: Dict[int, List[int]],
     ) -> None:
         # Issue cache operations.
-        issued_cache_op = False
+        # TODO(woosuk): Profile swapping overhead and optimize if needed.
         if blocks_to_swap_in:
             self.cache_engine.swap_in(blocks_to_swap_in)
-            issued_cache_op = True
         if blocks_to_swap_out:
             self.cache_engine.swap_out(blocks_to_swap_out)
-            issued_cache_op = True
         if blocks_to_copy:
             self.cache_engine.copy(blocks_to_copy)
-            issued_cache_op = True
-
-        cache_events = self.cache_events if issued_cache_op else None
-
-        # Wait for cache operations to finish.
-        # TODO(woosuk): Profile swapping overhead and optimize if needed.
-        if cache_events is not None:
-            for event in cache_events:
-                event.wait()
 
     @torch.inference_mode()
     def execute_model(
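Taken together, the two files above drop the dedicated cache stream and the per-layer CUDA events: swap_blocks now runs on the current stream, so the model's later kernels on that stream are ordered after the copies and the worker no longer needs to wait on events. A minimal sketch of the two synchronization patterns, in plain PyTorch rather than vLLM code; `copy_on_side_stream` and `copy_on_current_stream` are illustrative names, not functions from the repository.

import torch

def copy_on_side_stream(src: torch.Tensor, dst: torch.Tensor) -> torch.cuda.Event:
    """Old pattern: enqueue the copy on a dedicated stream and return an event
    that the caller must make its own stream wait on before reading dst."""
    side_stream = torch.cuda.Stream()
    event = torch.cuda.Event()
    with torch.cuda.stream(side_stream):
        dst.copy_(src, non_blocking=True)
        event.record(stream=side_stream)
    return event

def copy_on_current_stream(src: torch.Tensor, dst: torch.Tensor) -> None:
    """New pattern: enqueue the copy on the current stream; later kernels on
    that stream are automatically ordered after it, so no event is needed."""
    dst.copy_(src, non_blocking=True)

if __name__ == "__main__":
    if torch.cuda.is_available():
        a = torch.randn(4, device="cuda")
        b = torch.empty_like(a)
        copy_on_side_stream(a, b).wait()   # current stream waits on the event
        copy_on_current_stream(a, b)       # ordering comes from the stream itself
        torch.cuda.synchronize()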