[KVConnector] Remove v0-related kv connector components such as kv pipe and kv lookup buffer (#29705)
Signed-off-by: KuntaiDu <kuntai@uchicago.edu>
Parent: 652ba93da3
Commit: ece2825a29
@@ -1,160 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import os
import random

import torch
from tqdm import tqdm

from vllm.config import KVTransferConfig
from vllm.distributed.kv_transfer.kv_lookup_buffer.simple_buffer import SimpleBuffer
from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import PyNcclPipe

# TODO: the test depends on a lot of fields in the current implementation.
# We should have a standard interface instead of direct field access.


def test_run(my_rank, buffer, device):
    # buffer should be empty in the beginning
    if my_rank == 0:
        assert buffer.buffer_size == 0
        assert len(buffer.buffer) == 0

    print(f"My rank: {my_rank}, device: {device}")

    # insert
    tokens = torch.tensor([1, 2, 3]).to(device)
    roi = tokens > 0
    if my_rank == 0:
        key = 2.0 * torch.ones([5, 6]).to(device)
        value = 3.0 * torch.ones([5, 6]).to(device)

        placeholder = torch.tensor([1]).to(device)

        buffer.insert(tokens, roi, key, value, placeholder)

    torch.distributed.barrier()

    # drop_select
    if my_rank == 1:
        tok, roi_, key, value, hidden = buffer.drop_select(tokens, roi)
        assert torch.allclose(tokens, tok)
        assert torch.allclose(roi, roi_)
        assert torch.allclose(key, 2.0 * torch.ones([5, 6], device=device))
        assert torch.allclose(value, 3.0 * torch.ones([5, 6], device=device))
    torch.distributed.barrier()

    if my_rank == 0:
        assert buffer.buffer_size == 0
        assert len(buffer.buffer) == 0

    print(f"My rank: {my_rank}, Test run passed!")


def stress_test(my_rank, buf, device):
    torch.distributed.barrier()
    torch.manual_seed(100)

    reqs = [
        (
            torch.rand(100).to(device),  # tokens
            torch.ones(100).bool().to(device),  # roi
            torch.rand(100).to(device),  # key
            torch.rand(100).to(device),  # value
            torch.rand(100).to(device),  # hidden
        )
        for i in tqdm(range(200))
    ]

    random.seed(my_rank)
    random.shuffle(reqs)

    torch.distributed.barrier()

    n = 0

    # the buffer can only hold about 100 reqs,
    # so the sender will occasionally block to wait for the receiver.
    for req in tqdm(reqs):
        if my_rank == 0:
            buf.insert(*req)
        else:
            tok, roi, k, v, h = req
            tok_, roi_, k_, v_, h_ = buf.drop_select(tok, roi)

            if tok_ is None:
                assert roi_ is None
                assert k_ is None
                assert v_ is None
                assert h_ is None
                n += 1
            else:
                assert torch.allclose(tok, tok_)
                assert torch.allclose(roi, roi_)
                assert torch.allclose(k, k_)
                assert torch.allclose(v, v_)
                assert torch.allclose(h, h_)
    print(f"Rank {my_rank} done")
    torch.distributed.barrier()

    if my_rank == 0:
        x = torch.tensor([0])
        torch.distributed.recv(x, 1)
        # the number of Nones received equals the number of KV entries that
        # were not selected and are therefore still in the buffer
        assert x.item() == len(buf.buffer)
        # and the size of the buffer should be 1700 bytes per buffered entry
        print(buf.buffer_size)
        assert buf.buffer_size == 1700 * len(buf.buffer)
    else:
        torch.distributed.send(torch.tensor([n]), 0)

    print(f"My rank: {my_rank}, Passed stress test!")


if __name__ == "__main__":
    my_rank = int(os.environ["RANK"])

    torch.distributed.init_process_group(
        backend="gloo",
        init_method="tcp://localhost:12398",
        world_size=2,
        rank=my_rank,
    )

    print(f"initialized! My rank is {my_rank}")

    config = KVTransferConfig(
        kv_connector="P2pNcclConnector",
        kv_buffer_device="cuda",
        kv_buffer_size=1e9,
        kv_rank=my_rank,
        kv_role="kv_both",  # this arg doesn't matter in this test
        kv_parallel_size=2,
        kv_ip="127.0.0.1",
        kv_port=12345,
    )

    data_pipe = PyNcclPipe(
        local_rank=my_rank,
        config=config,
        device="cuda",
        port_offset=0,
    )
    cpu_pipe = PyNcclPipe(
        local_rank=my_rank,
        config=config,
        device="cpu",
        port_offset=1,
    )

    buffer = SimpleBuffer(cpu_pipe, data_pipe, 170000)

    test_run(my_rank, buffer, data_pipe.device)

    stress_test(my_rank, buffer, data_pipe.device)

    buffer.close()
    data_pipe.close()
    cpu_pipe.close()
    print("Done")
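The 1700-byte figure asserted in `stress_test` follows directly from the per-request payload sizes; a minimal sketch of that arithmetic (plain Python, no vLLM imports) is below.

import torch

# Hedged sketch: how `1700 * len(buf.buffer)` adds up. Each buffered request
# holds five tensors, and SimpleBuffer counts their raw byte sizes.
tokens = torch.rand(100)        # float32 -> 4 bytes per element
roi = torch.ones(100).bool()    # bool    -> 1 byte per element
key = torch.rand(100)
value = torch.rand(100)
hidden = torch.rand(100)

per_request_bytes = sum(
    t.element_size() * t.numel() for t in (tokens, roi, key, value, hidden)
)
# 4 tensors * 400 bytes (tokens/key/value/hidden) + 100 bytes (roi) = 1700
assert per_request_bytes == 1700
print(per_request_bytes)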
@@ -1,8 +0,0 @@
#!/bin/bash
RANK=0 python3 test_lookup_buffer.py &
PID0=$!
RANK=1 python3 test_lookup_buffer.py &
PID1=$!

wait $PID0
wait $PID1
@@ -1,62 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import subprocess
import sys

import pytest
import torch


def run_python_script(script_name, timeout):
    script_name = f"kv_transfer/{script_name}"
    try:
        # Start both processes asynchronously using Popen
        process0 = subprocess.Popen(
            [sys.executable, script_name],
            env={"RANK": "0"},  # Set the RANK environment variable for process 0
            stdout=sys.stdout,  # Pipe stdout to current stdout
            stderr=sys.stderr,  # Pipe stderr to current stderr
        )

        process1 = subprocess.Popen(
            [sys.executable, script_name],
            env={"RANK": "1"},  # Set the RANK environment variable for process 1
            stdout=sys.stdout,  # Pipe stdout to current stdout
            stderr=sys.stderr,  # Pipe stderr to current stderr
        )

        # Wait for both processes to complete, with a timeout
        process0.wait(timeout=timeout)
        process1.wait(timeout=timeout)

        # Check the return status of both processes
        if process0.returncode != 0:
            pytest.fail(f"Test {script_name} failed for RANK=0, {process0.returncode}")
        if process1.returncode != 0:
            pytest.fail(f"Test {script_name} failed for RANK=1, {process1.returncode}")

    except subprocess.TimeoutExpired:
        # If either process times out, terminate both and fail the test
        process0.terminate()
        process1.terminate()
        pytest.fail(f"Test {script_name} timed out")
    except Exception as e:
        pytest.fail(f"Test {script_name} failed with error: {str(e)}")


# Define the test cases using pytest's parametrize
@pytest.mark.parametrize(
    "script_name,timeout",
    [
        ("test_lookup_buffer.py", 60),  # 60-second timeout
        ("test_send_recv.py", 120),  # 120-second timeout
    ],
)
def test_run_python_script(script_name, timeout):
    # Check the number of GPUs
    if torch.cuda.device_count() < 2:
        pytest.skip(f"Skipping test {script_name} because <2 GPUs are available")

    # Run the test if there are at least 2 GPUs
    run_python_script(script_name, timeout)
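The harness above launches the two ranks as separate OS processes with a RANK environment variable. A rough alternative sketch, assuming a single parent process is acceptable, uses torch.multiprocessing to spawn the two ranks directly; the init_method port here is an arbitrary example value and the worker body is only a placeholder.

import torch
import torch.multiprocessing as mp


def _worker(rank: int):
    # Each spawned process joins a tiny two-rank gloo group, mirroring what
    # the deleted scripts do after reading RANK from the environment.
    torch.distributed.init_process_group(
        backend="gloo",
        init_method="tcp://localhost:12399",
        world_size=2,
        rank=rank,
    )
    # ... the per-rank test body would go here ...
    torch.distributed.barrier()
    torch.distributed.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(_worker, nprocs=2, join=True)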
@@ -1,154 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import os
import time

import torch
from tqdm import tqdm

from vllm.config import KVTransferConfig
from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import PyNcclPipe


def test_run(my_rank, pipe):
    print(f"rank {my_rank} test_run starts....")
    # test run
    x = torch.tensor([1]).to(pipe.device)
    y = torch.tensor([[2.0, 3.0, 4.0, 8.0]]).to(pipe.device)
    if my_rank == 0:
        pipe.send_tensor(x)
        print(f"rank {my_rank} sent tensor x")
        pipe.send_tensor(y)
        print(f"rank {my_rank} sent tensor y")
        x2 = pipe.recv_tensor()
        print(f"rank {my_rank} received x2 = ", x2)
        y2 = pipe.recv_tensor()
        print(f"rank {my_rank} received y2 = ", y2)

    else:
        x2 = pipe.recv_tensor()
        print(f"rank {my_rank} received x2 = ", x2)
        y2 = pipe.recv_tensor()
        print(f"rank {my_rank} received y2 = ", y2)
        pipe.send_tensor(x)
        print(f"rank {my_rank} sent tensor x")
        pipe.send_tensor(y)
        print(f"rank {my_rank} sent tensor y")

    assert torch.allclose(x, x2)
    assert torch.allclose(y, y2)

    print(f"rank {my_rank} test_run passed!")


def stress_test(my_rank, pipe):
    print(f"rank {my_rank} stress_test starts....")

    tensors: list[torch.Tensor | None] = []

    torch.distributed.barrier()
    torch.manual_seed(0)

    for i in tqdm(range(500)):
        mean = torch.rand(1).item() * 100
        std = torch.rand(1).item() * 100
        size = torch.randint(900, 1000, (2,))
        x = torch.normal(mean * 1.0, std * 1.0, size=size.tolist()).to(pipe.device)

        # 5% probability of sending a None
        if torch.rand(1).item() < 0.05:
            tensors.append(None)
            tensors.append(None)
            tensors.append(None)
        else:
            tensors.append(x)
            tensors.append(x.mean().unsqueeze(0))
            tensors.append(x.std().unsqueeze(0))

    torch.distributed.barrier()

    for i in tqdm(range(500)):
        if my_rank == int((i % 10) > 3):
            pipe.send_tensor(tensors[3 * i])
            pipe.send_tensor(tensors[3 * i + 1])
            pipe.send_tensor(tensors[3 * i + 2])
        else:
            x = pipe.recv_tensor()
            mean = pipe.recv_tensor()
            std = pipe.recv_tensor()

            if x is None:
                assert mean is None
                assert std is None
            else:
                assert torch.allclose(x, tensors[3 * i])
                assert x.mean() == mean[0]
                assert x.std() == std[0]

    torch.distributed.barrier()


def latency_test(my_rank, pipe, nelement, ntensor):
    latencies = []

    torch.distributed.barrier()

    for i in tqdm(range(500)):
        tensors = []

        if my_rank == 0:
            # create tensor
            tensors = [torch.rand(nelement).to(pipe.device) for _ in range(ntensor)]

        torch.distributed.barrier()

        if my_rank == 0:
            t = torch.tensor([time.time()], dtype=torch.float64).to(pipe.device)
            for tensor in tensors:
                pipe.send_tensor(tensor)
            pipe.send_tensor(t)
        else:
            for _ in range(ntensor):
                pipe.recv_tensor()
            t = pipe.recv_tensor()
            latencies.append(time.time() - t.item())

    torch.distributed.barrier()

    print("Latency test passed.")
    print("Latency:", torch.tensor(latencies).mean().item() * 1000, "ms")


if __name__ == "__main__":
    my_rank = int(os.environ["RANK"])

    torch.distributed.init_process_group(
        backend="gloo",
        init_method="tcp://localhost:12398",
        world_size=2,
        rank=my_rank,
    )

    config = KVTransferConfig(
        kv_connector="P2pNcclConnector",
        kv_buffer_device="cuda",
        kv_buffer_size=1e9,
        kv_rank=my_rank,
        kv_role="kv_both",  # this arg doesn't matter in this test
        kv_parallel_size=2,
        kv_ip="127.0.0.1",
        kv_port=12345,
    )

    pipe = PyNcclPipe(
        local_rank=my_rank,
        config=config,
    )

    test_run(my_rank, pipe)

    stress_test(my_rank, pipe)

    # Use this function if you want to test the latency of the pipe impl.
    # latency_test(my_rank, pipe, 1024 * 8 * 128, 80)
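The role schedule in `stress_test` hides behind the terse expression `int((i % 10) > 3)`; the small standalone illustration below shows which rank sends on which iteration, so both directions of the pipe get exercised.

# Hedged illustration of the sender/receiver schedule used above: rank 0
# sends when i % 10 is in {0, 1, 2, 3}, and rank 1 sends for the remaining
# six positions of each block of ten iterations.
for i in range(10):
    sender_rank = int((i % 10) > 3)
    print(f"iteration {i}: rank {sender_rank} sends, rank {1 - sender_rank} receives")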
@@ -1,9 +0,0 @@
#!/bin/bash

RANK=0 python3 test_send_recv.py &
PID0=$!
RANK=1 python3 test_send_recv.py &
PID1=$!

wait $PID0
wait $PID1
@@ -1,179 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This file contains a new class `KVLookupBufferBase` that allows developers to
think of KV cache operations as inserting new KV cache entries (`insert`)
into the lookup buffer and querying existing KV caches (`drop_select`)
from the lookup buffer.

This file also contains a new class `KVStoreBufferBase` that allows developers
to manage the KVCache buffer as a simple key-value storage buffer with basic
put/get operations.

These classes are abstracted behind the class `KVCacheBufferBase`.
"""

from abc import ABC, abstractmethod

import torch


class KVCacheBufferBase(ABC):
    """
    Abstract base class for a KVCache buffer.
    """

    @abstractmethod
    def close(self) -> None:
        """Close the buffer and release resources.

        This method is responsible for cleaning up resources related to the
        KVCache buffer when it is no longer needed.

        Raises:
            NotImplementedError: This method must be implemented in subclasses.
        """
        raise NotImplementedError


class KVLookupBufferBase(KVCacheBufferBase):
    """
    Abstract base class for a KVCache lookup buffer.

    This class provides an abstraction for a key-value (KV) cache lookup
    buffer.

    The key of the lookup buffer:
    - input_tokens: token IDs of the request
    - roi: a binary mask on top of input_tokens.
      - Purpose of roi: Since the KV cache may only be available for a subset
        of tokens in the input (for example, when vLLM is connected to an
        external KV cache service), roi specifies the subset of tokens that
        the KV cache is associated with.
      - NOTE: roi can be further extended to describe which part of KV the
        current process is holding (each process may only hold a part of KV
        due to TP and PP). This is not implemented for now.

    The value of the lookup buffer:
    - key: the key tensor in the KV cache
    - value: the value tensor in the KV cache
    - hidden: the final hidden state generated by model forwarding. This allows
      vLLM to bypass further model forwarding by transmitting the hidden state.
    """

    @abstractmethod
    def insert(
        self,
        input_tokens: torch.Tensor,
        roi: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        hidden: torch.Tensor,
    ) -> None:
        """Insert into the lookup buffer.

        The functionality is similar to the following python statement
        ```
        buffer[input_tokens, roi] = [key, value, hidden]
        ```

        FIXME: in the future, we should only have two arguments, key and value,
        where key is a tensor dict and value is a tensor dict.

        FIXME: we should transmit both sampler outputs and the hidden states.

        Args:
            input_tokens (torch.Tensor): token IDs.
            roi (torch.Tensor): A binary mask on top of the input tokens.
            key (torch.Tensor): The key tensor in the KV cache.
            value (torch.Tensor): The value tensor in the KV cache.
            hidden (torch.Tensor): The final hidden state tensor generated
                during model forwarding, transmitted to bypass model
                forwarding on the consumer side.

        Raises:
            NotImplementedError: This method must be implemented in subclasses.
        """
        raise NotImplementedError

    @abstractmethod
    def drop_select(
        self, input_tokens: torch.Tensor | None, roi: torch.Tensor | None
    ) -> list[torch.Tensor | None]:
        """Select and *drop* KV cache entries from the lookup buffer.

        The functionality is similar to the following python statements
        ```
        ret = buffer.pop(input_tokens, roi)
        return ret
        ```

        If `input_tokens` and `roi` are `None`, any of the KV caches in the
        buffer may be selected, returned, and removed from the buffer. This is
        useful when offloading the KV cache to a KV cache storage service.

        Args:
            input_tokens (torch.Tensor): token IDs.
            roi (torch.Tensor): A binary mask on top of the input tokens.

        Returns:
            list[Optional[torch.Tensor]]: A list of tensors. Can be None.

        Raises:
            NotImplementedError: This method must be implemented in subclasses.
        """
        raise NotImplementedError


class KVStoreBufferBase(KVCacheBufferBase):
    """
    Abstract base class for a KVCache storage buffer with key-value semantics.
    This class provides a simple key-value storage abstraction with basic
    put/get operations, which enables flexible, fine-grained control over
    KVCache transfer.

    The functionality is similar to a distributed key-value store, where:
    - Key: A unique string identifier for the cached entry
    - Value:
        - Tensor to be stored and retrieved
        - None (indicating deletion or empty value)
    """

    @abstractmethod
    def put(
        self,
        key: str,
        value: torch.Tensor | None,
    ) -> None:
        """Store a key-value pair in the buffer.

        Args:
            key (str): Unique identifier for a tensor; this tensor could be the
                key cache tensor, value cache tensor, or hidden state tensor
                generated during model forwarding.

            value (Optional[torch.Tensor]): Tensor to be stored.

        Raises:
            NotImplementedError: This method must be implemented in subclasses.
        """
        raise NotImplementedError

    @abstractmethod
    def get(
        self,
        key: str,
    ) -> torch.Tensor | None:
        """Retrieve a value from the buffer by key.

        Args:
            key (str): Unique identifier for a tensor; this tensor could be the
                key cache tensor, value cache tensor, or hidden state tensor
                generated during model forwarding.

        Returns:
            Optional[torch.Tensor]: Stored tensor if it exists, None otherwise.

        Raises:
            NotImplementedError: This method must be implemented in subclasses.
        """
        raise NotImplementedError
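To make the insert/drop_select contract described above concrete, here is a minimal, single-process sketch. It keys entries by the roi-masked tokens and pops on lookup; it does not subclass the vLLM ABC and is not the real SimpleBuffer implementation, just an illustration of the semantics.

from collections import deque

import torch


class InMemoryLookupBuffer:
    """Toy, in-memory stand-in for a KV lookup buffer (illustration only)."""

    def __init__(self) -> None:
        self.buffer: deque[list[torch.Tensor]] = deque()

    def insert(self, input_tokens, roi, key, value, hidden) -> None:
        # "buffer[input_tokens, roi] = [key, value, hidden]"
        self.buffer.append([input_tokens, roi, key, value, hidden])

    def drop_select(self, input_tokens, roi):
        # "ret = buffer.pop(input_tokens, roi)" -- the entry is removed.
        query = input_tokens[roi]
        for _ in range(len(self.buffer)):
            tokens, item_roi, key, value, hidden = self.buffer[0]
            if torch.equal(tokens[item_roi], query):
                return list(self.buffer.popleft())
            self.buffer.rotate(-1)
        return [None, None, None, None, None]


buf = InMemoryLookupBuffer()
tokens = torch.tensor([1, 2, 3])
roi = tokens > 0
buf.insert(tokens, roi, torch.ones(2, 3), torch.zeros(2, 3), torch.ones(1))
print(buf.drop_select(tokens, roi)[0])  # matching entry returned and dropped
print(buf.drop_select(tokens, roi)[0])  # None: already dropped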
@@ -1,164 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This file contains a new class `MooncakeStore` that allows developers to
think of KV cache transfer operations as putting new KV cache entries
into a remote KVStore-based lookup buffer and getting existing KV caches
from this remote lookup buffer.
"""

import json
import os
from dataclasses import dataclass

import torch
from safetensors.torch import load as safetensors_load
from safetensors.torch import save as safetensors_save

from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_lookup_buffer.base import KVStoreBufferBase
from vllm.logger import init_logger

DEFAULT_GLOBAL_SEGMENT_SIZE = 3355443200  # 3.125 GiB
DEFAULT_LOCAL_BUFFER_SIZE = 1073741824  # 1.0 GiB

logger = init_logger(__name__)


@dataclass
class MooncakeStoreConfig:
    local_hostname: str
    metadata_server: str
    global_segment_size: int
    local_buffer_size: int
    protocol: str
    device_name: str
    master_server_address: str

    @staticmethod
    def from_file(file_path: str) -> "MooncakeStoreConfig":
        """Load the config from a JSON file."""
        with open(file_path) as fin:
            config = json.load(fin)
        return MooncakeStoreConfig(
            local_hostname=config.get("local_hostname"),
            metadata_server=config.get("metadata_server"),
            global_segment_size=config.get(
                "global_segment_size", DEFAULT_GLOBAL_SEGMENT_SIZE
            ),
            local_buffer_size=config.get(
                "local_buffer_size", DEFAULT_LOCAL_BUFFER_SIZE
            ),
            protocol=config.get("protocol", "tcp"),
            device_name=config.get("device_name", ""),
            master_server_address=config.get("master_server_address"),
        )

    @staticmethod
    def load_from_env() -> "MooncakeStoreConfig":
        """Load config from a file specified in the environment variable."""
        config_file_path = os.getenv("MOONCAKE_CONFIG_PATH")
        if config_file_path is None:
            raise ValueError(
                "The environment variable 'MOONCAKE_CONFIG_PATH' is not set."
            )
        return MooncakeStoreConfig.from_file(config_file_path)


class MooncakeStore(KVStoreBufferBase):
    def __init__(
        self,
        config: VllmConfig,
    ):
        try:
            from mooncake.store import MooncakeDistributedStore
        except ImportError as e:
            raise ImportError(
                "Please install mooncake by following the instructions at "
                "https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md "  # noqa: E501
                "to run vLLM with MooncakeConnector."
            ) from e

        try:
            self.store = MooncakeDistributedStore()
            self.config = MooncakeStoreConfig.load_from_env()
            logger.info("Mooncake Configuration loaded successfully.")

            self.store.setup(
                self.config.local_hostname,
                self.config.metadata_server,
                self.config.global_segment_size,
                self.config.local_buffer_size,
                self.config.protocol,
                self.config.device_name,
                self.config.master_server_address,
            )

        except ValueError as e:
            logger.error("Configuration loading failed: %s", e)
            raise
        except Exception as exc:
            logger.error("An error occurred while loading the configuration: %s", exc)
            raise

    def close(self):
        # MooncakeDistributedStore will automatically call the destructor, so
        # it is unnecessary to close it manually.
        pass

    def put(
        self,
        key: str,
        value: torch.Tensor | None,
    ) -> None:
        # A message queue needs to be introduced before making it asynchronous.
        if value is not None:
            self._put_impl(key, value)

    def get(
        self,
        key: str,
    ) -> torch.Tensor | None:
        # A message queue needs to be introduced before making it asynchronous.
        value = self._get_impl(key)
        return value

    def _put_impl(
        self,
        key: str,
        value: torch.Tensor,
    ) -> None:
        """Put KVCache to Mooncake Store"""
        device_id = value.device.index if value.device.type == "cuda" else -1
        device_tensor = torch.tensor(device_id, dtype=torch.int32)
        value_bytes = safetensors_save({"tensor": value, "device_id": device_tensor})
        try:
            self.store.put(key, value_bytes)
        except TypeError as err:
            logger.error("Failed to put value into Mooncake Store: %s", err)
            raise TypeError("Mooncake Store Put Type Error.") from err

    def _get_impl(
        self,
        key: str,
    ) -> torch.Tensor | None:
        """Get KVCache from Mooncake Store"""
        try:
            data = self.store.get(key)
        except TypeError as err:
            logger.error("Failed to get value from Mooncake Store: %s", err)
            raise TypeError("Mooncake Store Get Type Error.") from err

        if data:
            loaded_tensors = safetensors_load(data)
            tensor = loaded_tensors["tensor"]
            device_id_tensor = loaded_tensors["device_id"]
            device_id = int(device_id_tensor.item())
            device = (
                torch.device("cuda", device_id)
                if device_id >= 0
                else torch.device("cpu")
            )
            return tensor.to(device)

        return None
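The serialization used by `_put_impl`/`_get_impl` above packs the tensor together with a device-id marker into a single safetensors blob, so the value can be restored onto the right device on the way out. A minimal, self-contained round-trip sketch:

import torch
from safetensors.torch import load as safetensors_load
from safetensors.torch import save as safetensors_save

# Hedged sketch of the blob format used above (tensor + device-id metadata).
value = torch.arange(6, dtype=torch.float32).reshape(2, 3)
device_id = value.device.index if value.device.type == "cuda" else -1
blob = safetensors_save(
    {"tensor": value, "device_id": torch.tensor(device_id, dtype=torch.int32)}
)

loaded = safetensors_load(blob)
restored_id = int(loaded["device_id"].item())
device = torch.device("cuda", restored_id) if restored_id >= 0 else torch.device("cpu")
print(loaded["tensor"].to(device))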
@@ -1,242 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Implements a distributed key-value (KV) cache transfer mechanism.

Key Features:
- Distributed KV cache transmission using PyNccl pipes.
- Non-blocking `insert`, blocking `drop_select`.
- Uses a CPU signal pipe to avoid race conditions.
- Handles buffer size constraints and provides a backpressure mechanism to
  stop the prefill instance when the decode instance is slow.
"""

import threading
from collections import deque

import torch

from vllm.distributed.kv_transfer.kv_lookup_buffer.base import KVLookupBufferBase
from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase
from vllm.logger import init_logger

logger = init_logger(__name__)


class SimpleBuffer(KVLookupBufferBase):
    def __init__(
        self, signal_pipe: KVPipeBase, data_pipe: KVPipeBase, buffer_size_thresh: float
    ):
        """
        signal_pipe: on CPU

        NOTE: on-device recv will block all threads in the process, making the
        KV cache producer unable to listen to new requests while transmitting
        KV cache. Luckily CPU recv only blocks the current thread, so we use
        CPU recv to listen to new requests.

        data_pipe: on device (e.g. GPU)
        """

        self.buffer: deque[list[torch.Tensor]] = deque()

        self.buffer_size = 0
        self.buffer_size_threshold = buffer_size_thresh
        self.buffer_cv = threading.Condition()
        self.signal_pipe = signal_pipe
        self.data_pipe = data_pipe
        self.request_handling_thread: threading.Thread | None = None

        self.normal_signal = torch.tensor([0], device="cpu")
        self.end_signal = None

    def _matches(
        self,
        tokens_roi_sender: list[torch.Tensor],
        tokens_roi_recver: list[torch.Tensor],
    ):
        # tokens_roi_sender: tokens and roi of the producer (in the buffer)
        # tokens_roi_recver: tokens and roi of the consumer (query)

        tokens_sender = tokens_roi_sender[0]
        tokens_recver = tokens_roi_recver[0]
        roi_sender = tokens_roi_sender[1]
        roi_recver = tokens_roi_recver[1]

        if tokens_recver is None:
            # consumer sends an empty request
            # semantics: DROP SELECT * LIMIT 1
            # so any of the data in the buffer can be drop-selected
            return True

        # Assuming that roi is a binary mask on tokens
        tokens_sender = tokens_sender[roi_sender]
        tokens_recver = tokens_recver[roi_recver]

        # simple common prefix matching
        min_length = min(len(tokens_sender), len(tokens_recver))
        if torch.allclose(tokens_sender[:min_length], tokens_recver[:min_length]):
            return min_length

        return 0

    def _send_tensor_and_dec_size(self, tensor: torch.Tensor | None) -> None:
        assert tensor is not None, "Use self.data_pipe.send(None) instead"
        self.buffer_size -= tensor.element_size() * tensor.numel()
        if tensor.dtype == torch.bool:
            tensor = tensor.float()
        self.data_pipe.send_tensor(tensor)

    def _get_element_size(self, data: list | torch.Tensor | None):
        if isinstance(data, torch.Tensor):
            return data.element_size() * data.numel()
        if not data:
            # cannot perform `not data` on a tensor
            # so this check needs to go after the check above
            return 0

        raise AssertionError(f"Unknown data type {type(data)}")

    def _add_to_buffer(
        self,
        input_tokens: torch.Tensor,
        roi: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        hidden: torch.Tensor,
    ):
        if isinstance(input_tokens, torch.Tensor):
            input_tokens = input_tokens.clone()
        if isinstance(roi, torch.Tensor):
            roi = roi.clone()
        if isinstance(key, torch.Tensor):
            key = key.clone()
        if isinstance(value, torch.Tensor):
            value = value.clone()
        if isinstance(hidden, torch.Tensor):
            hidden = hidden.clone()

        buffer_item = [input_tokens, roi, key, value, hidden]
        data_size = sum([self._get_element_size(data) for data in buffer_item])

        with self.buffer_cv:
            if self.buffer_size + data_size > self.buffer_size_threshold:
                # log outside the while loop to avoid this message being logged
                # repeatedly.
                logger.debug("KV transfer buffer is full. Handling...")
                while self.buffer_size + data_size > self.buffer_size_threshold:
                    self.buffer_cv.wait()

            self.buffer_size += data_size
            self.buffer.append(buffer_item)
            self.buffer_cv.notify()

    def _is_end_signal(self, signal):
        return signal is None

    def drop_select_handler(self):
        try:
            while True:
                signal = self.signal_pipe.recv_tensor()
                if self._is_end_signal(signal):
                    logger.info("Received end signal!")
                    break

                input_tokens = self.data_pipe.recv_tensor()

                roi = self.data_pipe.recv_tensor()
                assert roi is not None, (
                    "Please provide the roi when sending drop-select request"
                )
                roi = roi > 0.5
                tokens_roi_recver = [input_tokens, roi]

                def is_buffer_available(
                    tokens_roi_recver: list[torch.Tensor],
                ) -> bool:
                    # perform input tokens and roi matching
                    # FIXME: this matching is O(n), ideally it should be O(1)
                    # but this buffer size won't (and shouldn't) be too large so
                    # the fix is not urgent.
                    for _ in range(len(self.buffer)):
                        if self._matches(self.buffer[0], tokens_roi_recver) > 0:
                            return True
                        # rotate the element we just accessed to the end
                        self.buffer.rotate(-1)
                    return False

                with self.buffer_cv:
                    while not is_buffer_available(tokens_roi_recver):
                        logger.debug("KV transfer buffer is not available. Waiting...")
                        self.buffer_cv.wait()
                    # need to clone the tensor
                    # in case the tensor is freed before sending finishes
                    matched_item = self.buffer.popleft()
                    for tensor in matched_item:
                        self._send_tensor_and_dec_size(tensor)
                    self.buffer_cv.notify()

        except RuntimeError as e:
            if "Connection closed by peer" not in str(e):
                raise e

        logger.debug("Closing drop_select_handler")

    def drop_select(
        self, input_tokens: torch.Tensor | None, roi: torch.Tensor | None
    ) -> list[torch.Tensor | None]:
        assert self.request_handling_thread is None, (
            "drop_select should be called by the KV cache consumer "
            "(e.g. the decode vLLM instance)"
        )

        if isinstance(input_tokens, torch.Tensor):
            input_tokens = input_tokens.clone()
        if isinstance(roi, torch.Tensor):
            roi = roi.clone().float()

        self.signal_pipe.send_tensor(self.normal_signal)
        self.data_pipe.send_tensor(input_tokens)
        self.data_pipe.send_tensor(roi)

        input_tokens = self.data_pipe.recv_tensor()
        roi = self.data_pipe.recv_tensor()
        if roi is not None:
            # convert from float tensor to bool tensor
            # as PyNccl does not support sending bool tensor
            roi = roi > 0.5
        key = self.data_pipe.recv_tensor()
        value = self.data_pipe.recv_tensor()
        hidden = self.data_pipe.recv_tensor()

        return [input_tokens, roi, key, value, hidden]

    def insert(
        self,
        input_tokens: torch.Tensor,
        roi: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        hidden: torch.Tensor,
    ) -> None:
        self._add_to_buffer(input_tokens, roi, key, value, hidden)

        # when calling insert, the current process is a sender, so it needs
        # to launch the request handler and start listening to requests.
        if self.request_handling_thread is None:
            self.request_handling_thread = threading.Thread(
                target=self.drop_select_handler
            )
            self.request_handling_thread.start()

    def close(self):
        if (
            hasattr(self, "request_handling_thread")
            and self.request_handling_thread is not None
        ):
            self.request_handling_thread.join()

        else:
            # TODO: have an explicit close signal and an explicit way to
            # check if this side is the requester
            self.signal_pipe.send_tensor(self.end_signal)
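The heart of the producer-side lookup is the roi-masked common-prefix matching in `SimpleBuffer._matches`. A standalone illustration of just that logic (floats are used because tokens arrive as float tensors over PyNccl):

import torch


def matches(tokens_sender, roi_sender, tokens_recver, roi_recver) -> int:
    # Hedged re-statement of SimpleBuffer._matches for illustration only.
    if tokens_recver is None:
        return True  # empty query: any buffered entry may be drop-selected
    tokens_sender = tokens_sender[roi_sender]
    tokens_recver = tokens_recver[roi_recver]
    min_length = min(len(tokens_sender), len(tokens_recver))
    if torch.allclose(tokens_sender[:min_length], tokens_recver[:min_length]):
        return min_length
    return 0


sender_tokens = torch.tensor([1.0, 2.0, 3.0, 4.0])
recver_tokens = torch.tensor([1.0, 2.0, 3.0])
print(matches(sender_tokens, sender_tokens > 0, recver_tokens, recver_tokens > 0))  # 3
print(matches(sender_tokens, sender_tokens > 0,
              torch.tensor([9.0, 9.0]), torch.tensor([True, True])))  # 0 (no match)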
@@ -1,66 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This file defines an interface `KVPipeBase`
that provides an abstraction for sending and receiving tensors, or None, via
distributed communications.

All classes instantiated from this interface are assumed to be a FIFO pipe.

If your distributed communication platform already supports key-value lookup,
you can bypass this interface and directly start from `kv_lookup_buffer`.
"""

from abc import ABC, abstractmethod

import torch


class KVPipeBase(ABC):
    """
    This class provides an interface for sending and receiving tensors, or
    None, by distributed communications.
    """

    @abstractmethod
    def send_tensor(self, tensor: torch.Tensor | None) -> None:
        """Send a tensor, or None, via the pipe.

        Need to support sending None -- important for error handling.

        TODO: add a `key` argument so that we can use a traditional
        key-value database as the distributed communication mechanism behind
        the pipe.

        Args:
            tensor (Optional[torch.Tensor]): The tensor to be sent. Can be None.

        Raises:
            NotImplementedError: This method must be implemented in subclasses.
        """
        raise NotImplementedError

    @abstractmethod
    def recv_tensor(self) -> torch.Tensor | None:
        """Receive a tensor (can be None) from the pipeline.

        Returns:
            Optional[torch.Tensor]: The tensor received from the pipeline. Can
            be None.

        Raises:
            NotImplementedError: This method must be implemented in subclasses.
        """
        raise NotImplementedError

    @abstractmethod
    def close(self) -> None:
        """Close the pipeline and release resources.

        This method is responsible for closing the communication pipeline
        and releasing any resources associated with it.

        Raises:
            NotImplementedError: This method must be implemented in subclasses.
        """
        raise NotImplementedError
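The interface above is just three methods and a FIFO + optional-None contract. A toy, in-process sketch backed by queue.Queue (not a vLLM class, purely to show the shape of a conforming implementation):

import queue

import torch


class InProcessPipe:
    """Hedged sketch: queue-backed pipe with the KVPipeBase method surface."""

    def __init__(self) -> None:
        self._q: "queue.Queue[torch.Tensor | None]" = queue.Queue()

    def send_tensor(self, tensor: torch.Tensor | None) -> None:
        self._q.put(tensor)

    def recv_tensor(self) -> torch.Tensor | None:
        return self._q.get()

    def close(self) -> None:
        pass  # nothing to release for an in-process queue


pipe = InProcessPipe()
pipe.send_tensor(torch.ones(3))
pipe.send_tensor(None)      # None must be transmissible (error handling)
print(pipe.recv_tensor())   # tensor([1., 1., 1.])
print(pipe.recv_tensor())   # None
pipe.close()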
@@ -1,295 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import json
import os
import struct
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass

import torch
import zmq
from safetensors.torch import load as safetensors_load
from safetensors.torch import save as safetensors_save

from vllm.config.kv_transfer import KVTransferConfig
from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase
from vllm.logger import init_logger
from vllm.utils.network_utils import join_host_port, make_zmq_path, split_host_port

logger = init_logger(__name__)
NONE_INT = -150886311


@dataclass
class MooncakeTransferEngineConfig:
    prefill_url: str
    decode_url: str
    metadata_backend: str | None
    metadata_server: str
    protocol: str
    device_name: str

    @staticmethod
    def from_file(file_path: str) -> "MooncakeTransferEngineConfig":
        """Load the config from a JSON file."""
        with open(file_path) as fin:
            config = json.load(fin)
        return MooncakeTransferEngineConfig(
            prefill_url=config.get("prefill_url"),
            decode_url=config.get("decode_url"),
            metadata_backend=config.get("metadata_backend", None),
            metadata_server=config.get("metadata_server"),
            protocol=config.get("protocol", "tcp"),
            device_name=config.get("device_name", ""),
        )

    @staticmethod
    def load_from_env() -> "MooncakeTransferEngineConfig":
        """Load config from a file specified in the environment variable."""
        config_file_path = os.getenv("MOONCAKE_CONFIG_PATH")
        if config_file_path is None:
            raise ValueError(
                "The environment variable 'MOONCAKE_CONFIG_PATH' is not set."
            )
        return MooncakeTransferEngineConfig.from_file(config_file_path)


class MooncakeTransferEngine:
    """Handles the transfer of data using mooncake_vllm_adaptor and ZeroMQ."""

    def __init__(self, kv_rank: int, local_rank: int):
        try:
            from mooncake.engine import TransferEngine
        except ImportError as e:
            raise ImportError(
                "Please install mooncake by following the instructions at "
                "https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md "  # noqa: E501
                "to run vLLM with MooncakeConnector."
            ) from e

        self.engine = TransferEngine()
        self.local_rank = local_rank

        try:
            self.config = MooncakeTransferEngineConfig.load_from_env()
            logger.info("Mooncake Configuration loaded successfully.")
        except ValueError as e:
            logger.error(e)
            raise
        except Exception as exc:
            logger.error("An error occurred while loading the configuration: %s", exc)
            raise
        prefill_host, base_prefill_port = split_host_port(self.config.prefill_url)
        decode_host, base_decode_port = split_host_port(self.config.decode_url)

        # Avoid ports conflict when running prefill and decode on the same node
        if prefill_host == decode_host and base_prefill_port == base_decode_port:
            base_decode_port = base_decode_port + 100

        prefill_port = base_prefill_port + self.local_rank
        decode_port = base_decode_port + self.local_rank
        self.prefill_url = join_host_port(prefill_host, prefill_port)
        self.decode_url = join_host_port(decode_host, decode_port)

        self.initialize(
            self.prefill_url if kv_rank == 0 else self.decode_url,
            self.config.metadata_server,
            self.config.protocol,
            self.config.device_name,
            self.config.metadata_backend,
        )

        self.remote_url = self.decode_url if kv_rank == 0 else self.prefill_url

        # Initialize ZeroMQ context and sockets
        self.context = zmq.Context()  # type: ignore[attr-defined]
        self.sender_socket = self.context.socket(zmq.constants.PUSH)
        self.receiver_socket = self.context.socket(zmq.constants.PULL)
        self.sender_ack = self.context.socket(zmq.constants.PULL)
        self.receiver_ack = self.context.socket(zmq.constants.PUSH)

        self.buffer_cleaner = ThreadPoolExecutor(max_workers=1)
        self._setup_metadata_sockets(
            kv_rank, prefill_host, base_prefill_port, decode_host, base_decode_port
        )

    def _setup_metadata_sockets(
        self, kv_rank: int, p_host: str, p_port: int, d_host: str, d_port: int
    ) -> None:
        """Set up ZeroMQ sockets for sending and receiving data."""
        # Offsets < 8 are left for initialization in case tp and pp are enabled
        p_rank_offset = p_port + 8 + self.local_rank * 2
        d_rank_offset = d_port + 8 + self.local_rank * 2
        if kv_rank == 0:
            self.sender_socket.bind(make_zmq_path("tcp", p_host, p_rank_offset + 1))
            self.receiver_socket.connect(
                make_zmq_path("tcp", d_host, d_rank_offset + 1)
            )
            self.sender_ack.connect(make_zmq_path("tcp", d_host, d_rank_offset + 2))
            self.receiver_ack.bind(make_zmq_path("tcp", p_host, p_rank_offset + 2))
        else:
            self.receiver_socket.connect(
                make_zmq_path("tcp", p_host, p_rank_offset + 1)
            )
            self.sender_socket.bind(make_zmq_path("tcp", d_host, d_rank_offset + 1))
            self.receiver_ack.bind(make_zmq_path("tcp", d_host, d_rank_offset + 2))
            self.sender_ack.connect(make_zmq_path("tcp", p_host, p_rank_offset + 2))

    def initialize(
        self,
        local_hostname: str,
        metadata_server: str,
        protocol: str,
        device_name: str,
        metadata_backend: str | None,
    ) -> None:
        """Initialize the mooncake instance."""
        if metadata_backend is None:
            self.engine.initialize(
                local_hostname, metadata_server, protocol, device_name
            )
        else:
            supported_backend = ["etcd", "redis"]
            metadata_backend = metadata_backend.lower()
            if metadata_backend not in supported_backend:
                raise ValueError(
                    "Mooncake Configuration error. `metadata_backend`"
                    f" should be one of {supported_backend}."
                )

            self.engine.initialize_ext(
                local_hostname, metadata_server, protocol, device_name, metadata_backend
            )

    def allocate_managed_buffer(self, length: int) -> int:
        """Allocate a managed buffer of the specified length."""
        ret = self.engine.allocate_managed_buffer(length)
        if ret <= 0:
            logger.error("Allocation Return Error")
            raise Exception("Allocation Return Error")
        return ret

    def free_managed_buffer(self, buffer: int, length: int) -> int:
        """Free a previously allocated managed buffer."""
        return self.engine.free_managed_buffer(buffer, length)

    def transfer_sync(self, buffer: int, peer_buffer_address: int, length: int) -> int:
        """Synchronously transfer data to the specified address."""
        ret = self.engine.transfer_sync_read(
            self.remote_url, buffer, peer_buffer_address, length
        )
        if ret < 0:
            logger.error("Transfer Return Error")
            raise Exception("Transfer Return Error")
        return ret

    def write_bytes_to_buffer(self, buffer: int, user_data: bytes, length: int) -> int:
        """Write bytes to the allocated buffer."""
        return self.engine.write_bytes_to_buffer(buffer, user_data, length)

    def read_bytes_from_buffer(self, buffer: int, length: int) -> bytes:
        """Read bytes from the allocated buffer."""
        return self.engine.read_bytes_from_buffer(buffer, length)

    def wait_for_ack(self, src_ptr: int, length: int) -> None:
        """Asynchronously wait for ACK from the receiver."""
        ack = self.sender_ack.recv()
        if ack != b"ACK":
            logger.error("Failed to receive ACK from the receiver")

        self.free_managed_buffer(src_ptr, length)

    def send_bytes(self, user_data: bytes) -> None:
        """Send bytes to the remote process."""
        length = len(user_data)
        src_ptr = self.allocate_managed_buffer(length)
        self.write_bytes_to_buffer(src_ptr, user_data, length)
        self.sender_socket.send_multipart(
            [struct.pack("!Q", src_ptr), struct.pack("!Q", length)]
        )
        self.buffer_cleaner.submit(self.wait_for_ack, src_ptr, length)

    def recv_bytes(self) -> bytes:
        """Receive bytes from the remote process."""
        data = self.receiver_socket.recv_multipart()
        src_ptr = struct.unpack("!Q", data[0])[0]
        length = struct.unpack("!Q", data[1])[0]
        dst_ptr = self.allocate_managed_buffer(length)
        self.transfer_sync(dst_ptr, src_ptr, length)
        ret = self.read_bytes_from_buffer(dst_ptr, length)

        # Buffer cleanup
        self.receiver_ack.send(b"ACK")
        self.free_managed_buffer(dst_ptr, length)

        return ret


class MooncakePipe(KVPipeBase):
    """MooncakeTransferEngine based Pipe implementation."""

    def __init__(
        self, local_rank: int, config: KVTransferConfig, device: str | None = None
    ):
        """Initialize the mooncake pipe and set related parameters."""
        self.config = config
        self.local_rank = local_rank
        self.kv_rank = self.config.kv_rank
        assert self.kv_rank is not None
        if device is None:
            self.device = self._select_device(self.config.kv_buffer_device)
        else:
            self.device = self._select_device(device)

        self.transfer_engine = MooncakeTransferEngine(self.kv_rank, self.local_rank)
        self.transport_thread: ThreadPoolExecutor | None = None
        self.none_tensor = torch.tensor([NONE_INT], device=self.device)

    def _select_device(self, device: str) -> torch.device:
        """Select available device (CUDA or CPU)."""
        logger.info("Selecting device: %s", device)
        if device == "cuda":
            return torch.device(f"cuda:{self.local_rank}")
        else:
            return torch.device("cpu")

    def tensor_hash(self, tensor: torch.Tensor) -> int:
        """Calculate the hash value of the tensor."""
        return hash(tensor.data_ptr())

    def _send_impl(self, tensor: torch.Tensor) -> None:
        """Implement the tensor sending logic using safetensors."""
        self.transfer_engine.send_bytes(safetensors_save({"tensor": tensor}))

    def _recv_impl(self) -> torch.Tensor:
        """Implement the tensor receiving logic using safetensors."""
        data = self.transfer_engine.recv_bytes()
        return safetensors_load(data)["tensor"].to(self.device)

    def send_tensor(self, tensor: torch.Tensor | None) -> None:
        """Send tensor to the target process."""
        if self.transport_thread is None:
            self.transport_thread = ThreadPoolExecutor(max_workers=1)
        tensor = tensor if tensor is not None else self.none_tensor
        assert len(tensor.shape) > 0
        self.transport_thread.submit(self._send_impl, tensor)

    def recv_tensor(self) -> torch.Tensor | None:
        """Receive tensor from other processes."""
        if self.transport_thread is None:
            self.transport_thread = ThreadPoolExecutor(max_workers=1)
        tensor = self.transport_thread.submit(self._recv_impl).result()
        if tensor.numel() == 1 and tensor.item() == NONE_INT:
            return None
        else:
            return tensor

    def close(self) -> None:
        """Cleanup logic when closing the pipe."""
        self.transfer_engine.sender_socket.close()
        self.transfer_engine.receiver_socket.close()
        self.transfer_engine.sender_ack.close()
        self.transfer_engine.receiver_ack.close()
        self.transfer_engine.context.term()  # Terminate the ZMQ context
        logger.info("Closed the transfer engine and cleaned up resources.")
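`MooncakeTransferEngineConfig.load_from_env` expects MOONCAKE_CONFIG_PATH to point at a JSON file with the fields read in `from_file` above. A hedged example of writing such a file; the host/port values are placeholders, only the field names come from the config class.

import json

example_config = {
    "prefill_url": "192.168.0.1:13003",      # placeholder address
    "decode_url": "192.168.0.2:13003",       # placeholder address
    "metadata_server": "192.168.0.1:2379",   # placeholder address
    "metadata_backend": "etcd",              # optional; "etcd" or "redis"
    "protocol": "tcp",                       # defaults to "tcp"
    "device_name": "",                       # defaults to ""
}
with open("mooncake.json", "w") as fout:
    json.dump(example_config, fout, indent=2)
# Then: export MOONCAKE_CONFIG_PATH=mooncake.json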
@@ -1,285 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This module implements a PyNccl pipe for sending and receiving
Optional[torch.Tensor] between distributed ranks with advanced
communication features.

Key Features:
- Supports sending and receiving tensors with metadata
- Handles both CUDA and CPU device communications
- Implements a non-blocking tensor transfer mechanism
- Manages buffer size and provides backpressure control
- Supports distributed process groups with configurable parameters
"""

import threading
import time
from collections.abc import Callable
from concurrent.futures import ThreadPoolExecutor

import torch

from vllm.config.kv_transfer import KVTransferConfig
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase
from vllm.distributed.utils import StatelessProcessGroup
from vllm.logger import init_logger

logger = init_logger(__name__)


class BrokenPipeException(Exception):
    def __init__(self, message):
        self.message = message
        super().__init__(self.message)


Metadata = dict[str, torch.Tensor | None]


class PyNcclPipe(KVPipeBase):
    METADATA_LENGTH = 16
    MAX_TENSOR_DIMENSIONS = 14
    METADATA_DTYPE = torch.int64

    def __init__(
        self,
        local_rank: int,
        config: KVTransferConfig,
        device: str | None = None,
        port_offset: int = 0,
    ):
        self.config = config
        self.local_rank = local_rank
        self.kv_rank = self.config.kv_rank
        assert self.kv_rank is not None
        self.kv_parallel_size = self.config.kv_parallel_size
        if device is None:
            self.device = self._select_device(self.config.kv_buffer_device)
        else:
            self.device = self._select_device(device)

        # build distributed connection and send/recv implementation
        store_timeout = self.config.get_from_extra_config("store_timeout", 300)
        self.group = StatelessProcessGroup.create(
            host=self.config.kv_ip,
            port=self.config.kv_port + port_offset,
            rank=self.kv_rank,
            world_size=self.kv_parallel_size,
            store_timeout=store_timeout,
        )
        # add a barrier to make sure the connection is initiated properly
        self.group.barrier()
        impl = self._get_device_send_recv_impl(self.group)
        self.device_send_func, self.device_recv_func = impl
        # set target rank
        self.target_rank_for_send = (self.kv_rank + 1) % self.kv_parallel_size
        self.target_rank_for_recv = (self.kv_rank - 1) % self.kv_parallel_size

        # transportation-related variables
        self.transport_thread: ThreadPoolExecutor | None = None
        self.buffer_size = 0
        self.buffer_size_lock = threading.Lock()
        self.buffer_size_thresh = self.config.kv_buffer_size

    def _get_device_send_recv_impl(
        self, group: StatelessProcessGroup
    ) -> tuple[
        Callable[[torch.Tensor, int], None], Callable[[torch.Tensor, int], None]
    ]:
        send: Callable[[torch.Tensor, int], None]
        recv: Callable[[torch.Tensor, int], None]
        if self.device.type == "cuda":
            # use PyNCCL for send / recv
            comm = PyNcclCommunicator(group, device=self.local_rank)
            comm.disabled = False
            send, recv = comm.send, comm.recv  # type: ignore
        else:
            # This send / recv implementation here is NOT intended to transfer
            # KV caches (and should NOT be repurposed to transfer KV caches).
            # Currently it is only used to transmit control-plane messages
            # for PyNcclBuffer.
            send = group.send_obj

            def my_recv(x, src):
                x[...] = group.recv_obj(src)

            recv = my_recv

        return send, recv

    def _select_device(self, device: str):
        logger.info("Selecting device: %s", device)
        if device == "cuda":
            return torch.device(f"cuda:{self.local_rank}")
        else:
            return torch.device("cpu")

    def _make_metadata(self, tensor: torch.Tensor | None) -> Metadata:
        """
        Create the metadata as a dictionary based on the input tensor.

        Args:
            tensor: The input tensor or None if no tensor is provided.

        Returns:
            metadata: A dictionary with the following keys:
                - "dtype": The data type of the tensor or None.
                - "shape": The shape of the tensor or None.
        """
        if tensor is None:
            return {"dtype": None, "shape": None}
        else:
            return {"dtype": tensor.dtype, "shape": tensor.shape}

    def _prepare_recv_buffer(self, metadata: Metadata) -> torch.Tensor:
        """
        Create a buffer to receive the tensor based on the provided metadata.

        Args:
            metadata: A dictionary with keys "dtype" and "shape",
                describing the tensor's data type and shape.

        Returns:
            buffer: A tensor of the specified type and shape,
                allocated on `self.device`.
        """
        return torch.empty(
            metadata["shape"], dtype=metadata["dtype"], device=self.device
        )

    def _send_metadata(self, metadata: Metadata):
        """
        Send the metadata dictionary to the target rank.

        Args:
            metadata: A dictionary with keys "dtype" and "shape".
        """
        self.group.send_obj(metadata, self.target_rank_for_send)

    def _recv_metadata(self) -> Metadata:
        """
        Receive the metadata dictionary from the target rank.

        Returns:
            metadata: A dictionary with keys "dtype" and "shape"
                describing the tensor.
        """
        return self.group.recv_obj(self.target_rank_for_recv)

    def _send_impl(self, tensor: torch.Tensor | None) -> None:
        """
        The actual implementation of sending the tensor and its metadata to the
        target rank.

        Args:
            tensor: The input tensor to be sent, or `None` if no tensor is
                being sent.
        """
        metadata = self._make_metadata(tensor)
        self._send_metadata(metadata)
        if tensor is not None:
            self.device_send_func(tensor.to(self.device), self.target_rank_for_send)

    def _recv_impl(self) -> torch.Tensor | None:
        """
        The actual implementation of receiving a tensor and its metadata from
        the target rank.

        Returns:
            buffer: The received tensor, or `None` if no tensor is received.
        """
        metadata = self._recv_metadata()
        if metadata["dtype"] is None:
            return None
        buffer = self._prepare_recv_buffer(metadata)
        self.device_recv_func(buffer, self.target_rank_for_recv)

        return buffer

    def send_tensor_wrapper(
        self, tensor: torch.Tensor | None, tensor_size: int
    ) -> None:
        """
        Wrapper for _send_impl to handle exceptions and update buffer size.
        """
        try:
            self._send_impl(tensor)

            with self.buffer_size_lock:
                self.buffer_size -= tensor_size
        except Exception as e:
            logger.error(
                "[rank%d]: Exception when trying to send %s, msg: %s",
                torch.distributed.get_rank(),
                str(tensor),
                str(e),
            )
            import traceback

            traceback.print_exc()

    def block_if_full(self):
        """
        Block the current thread if the buffer size is larger than the
        threshold.
        """
        while self.buffer_size > self.buffer_size_thresh:
            logger.debug("KV cache transfer pipe is full. Waiting...")
            time.sleep(0.05)

    def send_tensor(self, tensor: torch.Tensor | None) -> None:
        """
        Sends a tensor and its metadata to the destination rank in a
        non-blocking way.

        Args:
            tensor: The tensor to send, or `None` if no tensor is being sent.
        """
        if self.transport_thread is None:
            self.transport_thread = ThreadPoolExecutor(max_workers=1)

        if tensor is not None:
            tensor_size = tensor.element_size() * tensor.numel()
        else:
            tensor_size = 0

        self.block_if_full()

        with self.buffer_size_lock:
            self.buffer_size += tensor_size

        self.transport_thread.submit(self.send_tensor_wrapper, tensor, tensor_size)

    def recv_tensor(self) -> torch.Tensor | None:
        """
        Receives a tensor and its metadata from the source rank. Blocking call.

        Returns:
            The received tensor, or `None` if no tensor is received.
        """
        if self.transport_thread is None:
            self.transport_thread = ThreadPoolExecutor(max_workers=1)

        future = self.transport_thread.submit(self._recv_impl)

        try:
            tensor = future.result()
        except Exception as e:
            logger.error("Encountering exception in KV receiving thread")
            logger.error("%s", e)
            logger.error("My device: %s", self.device)
            import traceback

            traceback.print_exc()
            raise e

        return tensor

    def close(self):
        """
        Close the pipe and release associated resources.
        """
        if hasattr(self, "transport_thread") and self.transport_thread is not None:
            self.transport_thread.shutdown()
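The core wire protocol of `PyNcclPipe._send_impl`/`_recv_impl` above is metadata-then-payload: a small dict with "dtype" and "shape" travels first, a None tensor is encoded as dtype=None, and the receiver allocates its buffer from the metadata before the payload arrives. A minimal, single-process sketch over a plain queue (not the actual NCCL transport):

import queue

import torch

channel: "queue.Queue" = queue.Queue()


def send_impl(tensor: torch.Tensor | None) -> None:
    # Metadata first; None is encoded purely in the metadata.
    if tensor is None:
        channel.put({"dtype": None, "shape": None})
        return
    channel.put({"dtype": tensor.dtype, "shape": tensor.shape})
    channel.put(tensor)


def recv_impl() -> torch.Tensor | None:
    metadata = channel.get()
    if metadata["dtype"] is None:
        return None
    # Allocate the receive buffer from the metadata, then fill it.
    buffer = torch.empty(metadata["shape"], dtype=metadata["dtype"])
    buffer.copy_(channel.get())
    return buffer


send_impl(torch.ones(2, 2))
send_impl(None)
print(recv_impl())  # 2x2 tensor of ones
print(recv_impl())  # None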