mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-24 02:35:01 +08:00
CPU KV Offloading: Use more CUDA streams (#29013)
Signed-off-by: Or Ozeri <oro@il.ibm.com>
This commit is contained in:
parent
9ccbf6b692
commit
174e39ead7
@ -9,7 +9,7 @@ import torch
|
|||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
|
from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
|
||||||
from vllm.v1.kv_offload.mediums import CPULoadStoreSpec, GPULoadStoreSpec
|
from vllm.v1.kv_offload.mediums import CPULoadStoreSpec, GPULoadStoreSpec
|
||||||
from vllm.v1.kv_offload.worker.cpu_gpu import CpuGpuOffloadingHandler
|
from vllm.v1.kv_offload.worker.cpu_gpu import CpuGpuOffloadingHandlers
|
||||||
|
|
||||||
BACKENDS_TO_TEST = [FlashAttentionBackend]
|
BACKENDS_TO_TEST = [FlashAttentionBackend]
|
||||||
|
|
||||||
@ -82,7 +82,7 @@ def test_transfer(
|
|||||||
|
|
||||||
# create handler
|
# create handler
|
||||||
cpu_block_size = gpu_blocks_per_cpu_block * gpu_block_size
|
cpu_block_size = gpu_blocks_per_cpu_block * gpu_block_size
|
||||||
handler = CpuGpuOffloadingHandler(
|
handlers = CpuGpuOffloadingHandlers(
|
||||||
attn_backends=attn_backends,
|
attn_backends=attn_backends,
|
||||||
gpu_block_size=gpu_block_size,
|
gpu_block_size=gpu_block_size,
|
||||||
cpu_block_size=cpu_block_size,
|
cpu_block_size=cpu_block_size,
|
||||||
@ -112,8 +112,7 @@ def test_transfer(
|
|||||||
|
|
||||||
# set transfer direction
|
# set transfer direction
|
||||||
if gpu_to_cpu:
|
if gpu_to_cpu:
|
||||||
src_kv_caches = handler.gpu_tensors
|
handler = handlers.gpu_to_cpu_handler
|
||||||
dst_kv_caches = handler.cpu_tensors
|
|
||||||
src_spec_class = GPULoadStoreSpec
|
src_spec_class = GPULoadStoreSpec
|
||||||
dst_spec_class = CPULoadStoreSpec
|
dst_spec_class = CPULoadStoreSpec
|
||||||
src_blocks = gpu_blocks
|
src_blocks = gpu_blocks
|
||||||
@ -122,8 +121,7 @@ def test_transfer(
|
|||||||
dst_blocks_in_gpu_block_size = cpu_blocks_in_gpu_block_size
|
dst_blocks_in_gpu_block_size = cpu_blocks_in_gpu_block_size
|
||||||
dst_size_in_gpu_blocks = num_cpu_blocks * gpu_blocks_per_cpu_block
|
dst_size_in_gpu_blocks = num_cpu_blocks * gpu_blocks_per_cpu_block
|
||||||
else:
|
else:
|
||||||
src_kv_caches = handler.cpu_tensors
|
handler = handlers.cpu_to_gpu_handler
|
||||||
dst_kv_caches = handler.gpu_tensors
|
|
||||||
src_spec_class = CPULoadStoreSpec
|
src_spec_class = CPULoadStoreSpec
|
||||||
dst_spec_class = GPULoadStoreSpec
|
dst_spec_class = GPULoadStoreSpec
|
||||||
src_blocks = cpu_blocks
|
src_blocks = cpu_blocks
|
||||||
@ -144,12 +142,12 @@ def test_transfer(
|
|||||||
dst_spec = dst_spec_class(dst_blocks)
|
dst_spec = dst_spec_class(dst_blocks)
|
||||||
|
|
||||||
# clone src and dst tensors before transfer
|
# clone src and dst tensors before transfer
|
||||||
orig_src_caches = [x.clone() for x in src_kv_caches]
|
orig_src_caches = [x.clone() for x in handler.src_tensors]
|
||||||
orig_dst_caches = [x.clone() for x in dst_kv_caches]
|
orig_dst_caches = [x.clone() for x in handler.dst_tensors]
|
||||||
|
|
||||||
# call transfer function
|
# call transfer function
|
||||||
assert handler.transfer_async(1, (src_spec, dst_spec))
|
assert handler.transfer_async(1, (src_spec, dst_spec))
|
||||||
assert set(handler.transfer_events.keys()) == {1}
|
assert set({x[0] for x in handler._transfers}) == {1}
|
||||||
|
|
||||||
# wait for transfer to complete
|
# wait for transfer to complete
|
||||||
end_time = time.time() + 10
|
end_time = time.time() + 10
|
||||||
@ -161,15 +159,15 @@ def test_transfer(
|
|||||||
time.sleep(0.1)
|
time.sleep(0.1)
|
||||||
|
|
||||||
# verify src tensors did not change
|
# verify src tensors did not change
|
||||||
for orig_tensor, tensor in zip(orig_src_caches, src_kv_caches):
|
for orig_tensor, tensor in zip(orig_src_caches, handler.src_tensors):
|
||||||
assert torch.equal(orig_tensor, tensor)
|
assert torch.equal(orig_tensor, tensor)
|
||||||
|
|
||||||
# verify dst tensors
|
# verify dst tensors
|
||||||
for dst_block in range(dst_size_in_gpu_blocks):
|
for dst_block in range(dst_size_in_gpu_blocks):
|
||||||
src_block_candidate = dst_to_src.get(dst_block)
|
src_block_candidate = dst_to_src.get(dst_block)
|
||||||
for src_cache, dst_cache, orig_dst_cache, kv_dim in zip(
|
for src_cache, dst_cache, orig_dst_cache, kv_dim in zip(
|
||||||
src_kv_caches,
|
handler.src_tensors,
|
||||||
dst_kv_caches,
|
handler.dst_tensors,
|
||||||
orig_dst_caches,
|
orig_dst_caches,
|
||||||
handler.kv_dim_before_num_blocks,
|
handler.kv_dim_before_num_blocks,
|
||||||
):
|
):
|
||||||
|
|||||||
@ -13,7 +13,7 @@ from vllm.v1.kv_offload.backends.cpu import CPUBackend
|
|||||||
from vllm.v1.kv_offload.lru_manager import LRUOffloadingManager
|
from vllm.v1.kv_offload.lru_manager import LRUOffloadingManager
|
||||||
from vllm.v1.kv_offload.mediums import CPULoadStoreSpec, GPULoadStoreSpec
|
from vllm.v1.kv_offload.mediums import CPULoadStoreSpec, GPULoadStoreSpec
|
||||||
from vllm.v1.kv_offload.spec import OffloadingSpec
|
from vllm.v1.kv_offload.spec import OffloadingSpec
|
||||||
from vllm.v1.kv_offload.worker.cpu_gpu import CpuGpuOffloadingHandler
|
from vllm.v1.kv_offload.worker.cpu_gpu import CpuGpuOffloadingHandlers
|
||||||
from vllm.v1.kv_offload.worker.worker import OffloadingHandler
|
from vllm.v1.kv_offload.worker.worker import OffloadingHandler
|
||||||
|
|
||||||
|
|
||||||
@ -32,7 +32,7 @@ class CPUOffloadingSpec(OffloadingSpec):
|
|||||||
self._manager: OffloadingManager | None = None
|
self._manager: OffloadingManager | None = None
|
||||||
|
|
||||||
# worker-side
|
# worker-side
|
||||||
self._handler: OffloadingHandler | None = None
|
self._handlers: CpuGpuOffloadingHandlers | None = None
|
||||||
|
|
||||||
self.eviction_policy: str = self.extra_config.get("eviction_policy", "lru")
|
self.eviction_policy: str = self.extra_config.get("eviction_policy", "lru")
|
||||||
|
|
||||||
@ -67,13 +67,13 @@ class CPUOffloadingSpec(OffloadingSpec):
|
|||||||
kv_caches: dict[str, torch.Tensor],
|
kv_caches: dict[str, torch.Tensor],
|
||||||
attn_backends: dict[str, type[AttentionBackend]],
|
attn_backends: dict[str, type[AttentionBackend]],
|
||||||
) -> Iterator[tuple[type[LoadStoreSpec], type[LoadStoreSpec], OffloadingHandler]]:
|
) -> Iterator[tuple[type[LoadStoreSpec], type[LoadStoreSpec], OffloadingHandler]]:
|
||||||
if not self._handler:
|
if not self._handlers:
|
||||||
if not current_platform.is_cuda_alike():
|
if not current_platform.is_cuda_alike():
|
||||||
raise Exception(
|
raise Exception(
|
||||||
"CPU Offloading is currently only supported on CUDA-alike GPUs"
|
"CPU Offloading is currently only supported on CUDA-alike GPUs"
|
||||||
)
|
)
|
||||||
|
|
||||||
self._handler = CpuGpuOffloadingHandler(
|
self._handlers = CpuGpuOffloadingHandlers(
|
||||||
attn_backends=attn_backends,
|
attn_backends=attn_backends,
|
||||||
gpu_block_size=self.gpu_block_size,
|
gpu_block_size=self.gpu_block_size,
|
||||||
cpu_block_size=self.offloaded_block_size,
|
cpu_block_size=self.offloaded_block_size,
|
||||||
@ -81,6 +81,6 @@ class CPUOffloadingSpec(OffloadingSpec):
|
|||||||
gpu_caches=kv_caches,
|
gpu_caches=kv_caches,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert self._handler is not None
|
assert self._handlers is not None
|
||||||
yield GPULoadStoreSpec, CPULoadStoreSpec, self._handler
|
yield GPULoadStoreSpec, CPULoadStoreSpec, self._handlers.gpu_to_cpu_handler
|
||||||
yield CPULoadStoreSpec, GPULoadStoreSpec, self._handler
|
yield CPULoadStoreSpec, GPULoadStoreSpec, self._handlers.cpu_to_gpu_handler
|
||||||
|
|||||||
@ -1,5 +1,6 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
from collections import deque
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@ -8,7 +9,7 @@ from vllm import _custom_ops as ops
|
|||||||
from vllm.attention.backends.abstract import AttentionBackend
|
from vllm.attention.backends.abstract import AttentionBackend
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.utils.platform_utils import is_pin_memory_available
|
from vllm.utils.platform_utils import is_pin_memory_available
|
||||||
from vllm.v1.kv_offload.mediums import CPULoadStoreSpec, GPULoadStoreSpec
|
from vllm.v1.kv_offload.mediums import BlockIDsLoadStoreSpec
|
||||||
from vllm.v1.kv_offload.worker.worker import (
|
from vllm.v1.kv_offload.worker.worker import (
|
||||||
OffloadingHandler,
|
OffloadingHandler,
|
||||||
TransferResult,
|
TransferResult,
|
||||||
@ -51,7 +52,123 @@ def expand_block_ids(
|
|||||||
output_idx = output_end_idx
|
output_idx = output_end_idx
|
||||||
|
|
||||||
|
|
||||||
class CpuGpuOffloadingHandler(OffloadingHandler):
|
class SingleDirectionOffloadingHandler(OffloadingHandler):
|
||||||
|
"""
|
||||||
|
SingleDirectionOffloadingHandler handles transfers for a single direction,
|
||||||
|
either CPU->GPU or GPU->CPU.
|
||||||
|
Transfers are guaranteed to be executed in order of their submission.
|
||||||
|
Each transfer uses a unique CUDA stream, and its stream will start
|
||||||
|
executing only after the streams of previous transfers have finished.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
src_tensors: list[torch.Tensor],
|
||||||
|
dst_tensors: list[torch.Tensor],
|
||||||
|
kv_dim_before_num_blocks: list[bool],
|
||||||
|
src_block_size_factor: int,
|
||||||
|
dst_block_size_factor: int,
|
||||||
|
priority: int,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize a SingleDirectionOffloadingHandler.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
src_tensors: list of KV cache tensors to copy from.
|
||||||
|
dst_tensors: list of KV cache tensors to copy to.
|
||||||
|
Order should match src_tensors.
|
||||||
|
kv_dim_before_num_blocks: list of bools, indicating
|
||||||
|
whether the respective KV cache tensor has a KV
|
||||||
|
dimension before its num_blocks dimension.
|
||||||
|
e.g. (2, num_blocks, ...)
|
||||||
|
src_block_size_factor: The number of kernel blocks
|
||||||
|
per KV block in a source tensor.
|
||||||
|
dst_block_size_factor: The number of kernel blocks
|
||||||
|
per KV block in a destination tensor.
|
||||||
|
priority: The priority of the backing CUDA streams.
|
||||||
|
Lower numbers indicate higher priority.
|
||||||
|
"""
|
||||||
|
assert len(src_tensors) == len(dst_tensors) == len(kv_dim_before_num_blocks)
|
||||||
|
|
||||||
|
self.src_tensors: list[torch.Tensor] = src_tensors
|
||||||
|
self.dst_tensors: list[torch.Tensor] = dst_tensors
|
||||||
|
self.kv_dim_before_num_blocks: list[bool] = kv_dim_before_num_blocks
|
||||||
|
self.src_block_size_factor: int = src_block_size_factor
|
||||||
|
self.dst_block_size_factor: int = dst_block_size_factor
|
||||||
|
self.priority = priority
|
||||||
|
|
||||||
|
# queue of transfers (job_id, stream, event)
|
||||||
|
self._transfers: deque[tuple[int, torch.cuda.Stream, torch.Event]] = deque()
|
||||||
|
# list of CUDA streams available for re-use
|
||||||
|
self._stream_pool: list[torch.cuda.Stream] = []
|
||||||
|
# list of CUDA events available for re-use
|
||||||
|
self._event_pool: list[torch.Event] = []
|
||||||
|
|
||||||
|
def transfer_async(self, job_id: int, transfer_spec: TransferSpec) -> bool:
|
||||||
|
src_spec, dst_spec = transfer_spec
|
||||||
|
assert isinstance(src_spec, BlockIDsLoadStoreSpec)
|
||||||
|
assert isinstance(dst_spec, BlockIDsLoadStoreSpec)
|
||||||
|
|
||||||
|
src_blocks = src_spec.block_ids
|
||||||
|
dst_blocks = dst_spec.block_ids
|
||||||
|
assert src_blocks.ndim == 1
|
||||||
|
assert dst_blocks.ndim == 1
|
||||||
|
|
||||||
|
src_sub_block_count = src_blocks.size * self.src_block_size_factor
|
||||||
|
dst_sub_block_count = dst_blocks.size * self.dst_block_size_factor
|
||||||
|
src_sub_blocks_to_skip = -dst_blocks.size % self.src_block_size_factor
|
||||||
|
|
||||||
|
assert dst_sub_block_count == src_sub_block_count - src_sub_blocks_to_skip
|
||||||
|
|
||||||
|
src_to_dst = np.empty((dst_sub_block_count, 2), dtype=np.int64)
|
||||||
|
expand_block_ids(
|
||||||
|
src_blocks,
|
||||||
|
self.src_block_size_factor,
|
||||||
|
src_to_dst[:, 0],
|
||||||
|
skip_count=src_sub_blocks_to_skip,
|
||||||
|
)
|
||||||
|
expand_block_ids(dst_blocks, self.dst_block_size_factor, src_to_dst[:, 1])
|
||||||
|
src_to_dst_tensor = torch.from_numpy(src_to_dst)
|
||||||
|
|
||||||
|
stream = (
|
||||||
|
self._stream_pool.pop()
|
||||||
|
if self._stream_pool
|
||||||
|
else torch.cuda.Stream(priority=self.priority)
|
||||||
|
)
|
||||||
|
event = self._event_pool.pop() if self._event_pool else torch.Event()
|
||||||
|
if self._transfers:
|
||||||
|
_, _, last_event = self._transfers[-1]
|
||||||
|
# assure job will start only after the previous one completes
|
||||||
|
stream.wait_event(last_event)
|
||||||
|
with torch.cuda.stream(stream):
|
||||||
|
for src_tensor, dst_tensor, kv_dim in zip(
|
||||||
|
self.src_tensors, self.dst_tensors, self.kv_dim_before_num_blocks
|
||||||
|
):
|
||||||
|
if kv_dim:
|
||||||
|
src_key_cache, src_value_cache = src_tensor
|
||||||
|
dst_key_cache, dst_value_cache = dst_tensor
|
||||||
|
ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst_tensor)
|
||||||
|
ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst_tensor)
|
||||||
|
else:
|
||||||
|
ops.swap_blocks(src_tensor, dst_tensor, src_to_dst_tensor)
|
||||||
|
event.record(stream)
|
||||||
|
|
||||||
|
self._transfers.append((job_id, stream, event))
|
||||||
|
|
||||||
|
# success
|
||||||
|
return True
|
||||||
|
|
||||||
|
def get_finished(self) -> list[TransferResult]:
|
||||||
|
results: list[TransferResult] = []
|
||||||
|
while self._transfers and self._transfers[0][2].query():
|
||||||
|
job_id, stream, event = self._transfers.popleft()
|
||||||
|
results.append((job_id, True))
|
||||||
|
self._stream_pool.append(stream)
|
||||||
|
self._event_pool.append(event)
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
class CpuGpuOffloadingHandlers:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
gpu_block_size: int,
|
gpu_block_size: int,
|
||||||
@ -60,27 +177,20 @@ class CpuGpuOffloadingHandler(OffloadingHandler):
|
|||||||
gpu_caches: dict[str, torch.Tensor],
|
gpu_caches: dict[str, torch.Tensor],
|
||||||
attn_backends: dict[str, type[AttentionBackend]],
|
attn_backends: dict[str, type[AttentionBackend]],
|
||||||
):
|
):
|
||||||
|
assert gpu_caches
|
||||||
assert cpu_block_size % gpu_block_size == 0
|
assert cpu_block_size % gpu_block_size == 0
|
||||||
self.block_size_factor = cpu_block_size // gpu_block_size
|
block_size_factor = cpu_block_size // gpu_block_size
|
||||||
|
|
||||||
# cuda streams for gpu->cpu and cpu->gpu
|
|
||||||
self.d2h_stream = torch.cuda.Stream()
|
|
||||||
self.h2d_stream = torch.cuda.Stream()
|
|
||||||
|
|
||||||
# job_id -> transfer cuda event
|
|
||||||
self.transfer_events: dict[int, torch.Event] = {}
|
|
||||||
# list of cuda events available for re-use
|
|
||||||
self.events_pool: list[torch.Event] = []
|
|
||||||
|
|
||||||
pin_memory = is_pin_memory_available()
|
pin_memory = is_pin_memory_available()
|
||||||
|
|
||||||
# allocate cpu tensors
|
# allocate cpu tensors
|
||||||
logger.info("Allocating %d CPU tensors...", len(gpu_caches))
|
logger.info("Allocating %d CPU tensors...", len(gpu_caches))
|
||||||
self.gpu_tensors: list[torch.Tensor] = []
|
gpu_tensors: list[torch.Tensor] = []
|
||||||
self.cpu_tensors: list[torch.Tensor] = []
|
cpu_tensors: list[torch.Tensor] = []
|
||||||
self.kv_dim_before_num_blocks: list[bool] = []
|
kv_dim_before_num_blocks: list[bool] = []
|
||||||
|
kernel_block_size: int | None = None
|
||||||
for layer_name, gpu_tensor in gpu_caches.items():
|
for layer_name, gpu_tensor in gpu_caches.items():
|
||||||
self.gpu_tensors.append(gpu_tensor)
|
gpu_tensors.append(gpu_tensor)
|
||||||
|
|
||||||
gpu_shape = gpu_tensor.shape
|
gpu_shape = gpu_tensor.shape
|
||||||
attn_backend = attn_backends[layer_name]
|
attn_backend = attn_backends[layer_name]
|
||||||
@ -88,16 +198,21 @@ class CpuGpuOffloadingHandler(OffloadingHandler):
|
|||||||
num_blocks=1234, block_size=16, num_kv_heads=8, head_size=256
|
num_blocks=1234, block_size=16, num_kv_heads=8, head_size=256
|
||||||
)
|
)
|
||||||
|
|
||||||
|
has_layers_dim = False
|
||||||
if len(gpu_shape) != len(test_shape):
|
if len(gpu_shape) != len(test_shape):
|
||||||
# cross-layers tensor
|
# cross-layers tensor
|
||||||
# shape is (num_blocks, ...)
|
# shape is (num_blocks, ...)
|
||||||
assert len(gpu_shape) == len(test_shape) + 1
|
assert len(gpu_shape) == len(test_shape) + 1
|
||||||
num_blocks_idx = 0
|
num_blocks_idx = 0
|
||||||
self.kv_dim_before_num_blocks.append(False)
|
has_layers_dim = True
|
||||||
|
kv_dim_before_num_blocks.append(False)
|
||||||
|
|
||||||
|
# prepend a dummy num_layers=80 to test_shape
|
||||||
|
test_shape = (80,) + test_shape
|
||||||
elif test_shape[0] == 1234:
|
elif test_shape[0] == 1234:
|
||||||
# shape is (num_blocks, ...)
|
# shape is (num_blocks, ...)
|
||||||
num_blocks_idx = 0
|
num_blocks_idx = 0
|
||||||
self.kv_dim_before_num_blocks.append(False)
|
kv_dim_before_num_blocks.append(False)
|
||||||
else:
|
else:
|
||||||
# shape should be (2, num_blocks, ...)
|
# shape should be (2, num_blocks, ...)
|
||||||
assert test_shape[0] == 2
|
assert test_shape[0] == 2
|
||||||
@ -105,13 +220,32 @@ class CpuGpuOffloadingHandler(OffloadingHandler):
|
|||||||
assert gpu_shape[0] == 2
|
assert gpu_shape[0] == 2
|
||||||
|
|
||||||
num_blocks_idx = 1
|
num_blocks_idx = 1
|
||||||
self.kv_dim_before_num_blocks.append(True)
|
kv_dim_before_num_blocks.append(True)
|
||||||
|
|
||||||
|
try:
|
||||||
|
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order(
|
||||||
|
include_num_layers_dimension=has_layers_dim
|
||||||
|
)
|
||||||
|
assert len(kv_cache_stride_order) == len(gpu_shape)
|
||||||
|
except (AttributeError, NotImplementedError):
|
||||||
|
kv_cache_stride_order = tuple(range(len(gpu_shape)))
|
||||||
|
|
||||||
|
# permute test_shape according to stride_order
|
||||||
|
test_shape = tuple(test_shape[i] for i in kv_cache_stride_order)
|
||||||
|
|
||||||
|
# find block_size (16) dimension index
|
||||||
|
block_size_idx = test_shape.index(16)
|
||||||
|
if kernel_block_size is not None:
|
||||||
|
assert kernel_block_size == gpu_shape[block_size_idx]
|
||||||
|
else:
|
||||||
|
kernel_block_size = gpu_shape[block_size_idx]
|
||||||
|
assert gpu_block_size % kernel_block_size == 0
|
||||||
|
|
||||||
cpu_shape = list(gpu_shape)
|
cpu_shape = list(gpu_shape)
|
||||||
cpu_shape[num_blocks_idx] = num_cpu_blocks * self.block_size_factor
|
cpu_shape[num_blocks_idx] = num_cpu_blocks * block_size_factor
|
||||||
|
|
||||||
logger.debug("Allocating CPU tensor of shape %r", cpu_shape)
|
logger.debug("Allocating CPU tensor of shape %r", cpu_shape)
|
||||||
self.cpu_tensors.append(
|
cpu_tensors.append(
|
||||||
torch.zeros(
|
torch.zeros(
|
||||||
cpu_shape,
|
cpu_shape,
|
||||||
dtype=gpu_tensor.dtype,
|
dtype=gpu_tensor.dtype,
|
||||||
@ -120,72 +254,27 @@ class CpuGpuOffloadingHandler(OffloadingHandler):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
def transfer_async(self, job_id: int, spec: TransferSpec) -> bool:
|
assert kernel_block_size is not None
|
||||||
src_spec, dst_spec = spec
|
gpu_block_size_factor = gpu_block_size // kernel_block_size
|
||||||
if isinstance(src_spec, CPULoadStoreSpec):
|
cpu_block_size_factor = cpu_block_size // kernel_block_size
|
||||||
assert isinstance(dst_spec, GPULoadStoreSpec)
|
|
||||||
stream = self.h2d_stream
|
|
||||||
src_tensors = self.cpu_tensors
|
|
||||||
dst_tensors = self.gpu_tensors
|
|
||||||
src_block_size_factor = self.block_size_factor
|
|
||||||
dst_block_size_factor = 1
|
|
||||||
else:
|
|
||||||
assert isinstance(src_spec, GPULoadStoreSpec)
|
|
||||||
assert isinstance(dst_spec, CPULoadStoreSpec)
|
|
||||||
stream = self.d2h_stream
|
|
||||||
src_tensors = self.gpu_tensors
|
|
||||||
dst_tensors = self.cpu_tensors
|
|
||||||
src_block_size_factor = 1
|
|
||||||
dst_block_size_factor = self.block_size_factor
|
|
||||||
|
|
||||||
src_blocks = src_spec.block_ids
|
# TODO (orozery): adapt swap_blocks to support gpu_block_size_factor
|
||||||
dst_blocks = dst_spec.block_ids
|
assert gpu_block_size_factor == 1
|
||||||
assert src_blocks.ndim == 1
|
|
||||||
assert dst_blocks.ndim == 1
|
|
||||||
|
|
||||||
src_sub_block_count = src_blocks.size * src_block_size_factor
|
self.gpu_to_cpu_handler = SingleDirectionOffloadingHandler(
|
||||||
dst_sub_block_count = dst_blocks.size * dst_block_size_factor
|
src_tensors=gpu_tensors,
|
||||||
src_sub_blocks_to_skip = -dst_blocks.size % src_block_size_factor
|
dst_tensors=cpu_tensors,
|
||||||
|
kv_dim_before_num_blocks=kv_dim_before_num_blocks,
|
||||||
assert dst_sub_block_count == src_sub_block_count - src_sub_blocks_to_skip
|
src_block_size_factor=gpu_block_size_factor,
|
||||||
|
dst_block_size_factor=cpu_block_size_factor,
|
||||||
src_to_dst = np.empty((dst_sub_block_count, 2), dtype=np.int64)
|
priority=1,
|
||||||
expand_block_ids(
|
|
||||||
src_blocks,
|
|
||||||
src_block_size_factor,
|
|
||||||
src_to_dst[:, 0],
|
|
||||||
skip_count=src_sub_blocks_to_skip,
|
|
||||||
)
|
)
|
||||||
expand_block_ids(dst_blocks, dst_block_size_factor, src_to_dst[:, 1])
|
|
||||||
src_to_dst_tensor = torch.from_numpy(src_to_dst)
|
|
||||||
|
|
||||||
event = self.events_pool.pop() if self.events_pool else torch.Event()
|
self.cpu_to_gpu_handler = SingleDirectionOffloadingHandler(
|
||||||
with torch.cuda.stream(stream):
|
src_tensors=cpu_tensors,
|
||||||
for src_tensor, dst_tensor, kv_dim in zip(
|
dst_tensors=gpu_tensors,
|
||||||
src_tensors, dst_tensors, self.kv_dim_before_num_blocks
|
kv_dim_before_num_blocks=kv_dim_before_num_blocks,
|
||||||
):
|
src_block_size_factor=cpu_block_size_factor,
|
||||||
if kv_dim:
|
dst_block_size_factor=gpu_block_size_factor,
|
||||||
src_key_cache = src_tensor[0]
|
priority=-1,
|
||||||
dst_key_cache = dst_tensor[0]
|
)
|
||||||
ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst_tensor)
|
|
||||||
src_value_cache = src_tensor[1]
|
|
||||||
dst_value_cache = dst_tensor[1]
|
|
||||||
ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst_tensor)
|
|
||||||
else:
|
|
||||||
ops.swap_blocks(src_tensor, dst_tensor, src_to_dst_tensor)
|
|
||||||
event.record(stream)
|
|
||||||
|
|
||||||
self.transfer_events[job_id] = event
|
|
||||||
|
|
||||||
# success
|
|
||||||
return True
|
|
||||||
|
|
||||||
def get_finished(self) -> list[TransferResult]:
|
|
||||||
results: list[TransferResult] = []
|
|
||||||
for job_id, event in self.transfer_events.items():
|
|
||||||
if event.query():
|
|
||||||
results.append((job_id, True))
|
|
||||||
self.events_pool.append(event)
|
|
||||||
for job_id, _ in results:
|
|
||||||
del self.transfer_events[job_id]
|
|
||||||
return results
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user