[V1][P/D]Enhance Performance and code readability for P2pNcclConnector (#20906)

Signed-off-by: Abatom <abzhonghua@gmail.com>
This commit is contained in:
Zhonghua Deng 2025-07-17 13:13:00 +08:00 committed by GitHub
parent 76b494444f
commit 8a4e5c5f3c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 262 additions and 252 deletions

View File

@ -31,7 +31,7 @@ Each P/D instance periodically sends a heartbeat packet to the Proxy/Router (cur
## KV Cache Transfer Methods ## KV Cache Transfer Methods
There are three methods for KVcache transfer: PUT, GET, and PUT_ASYNC. These methods can be specified using the `--kv-transfer-config` and `kv_connector_extra_config` parameters, specifically through the `send_type` field. Both PUT and PUT_ASYNC involve the P instance actively sending KVcache to the D instance. The difference is that PUT is a synchronous transfer method that blocks the main process, while PUT_ASYNC is an asynchronous transfer method. PUT_ASYNC uses a dedicated thread for sending KVcache, which means it does not block the main process. In contrast, the GET method involves the P instance saving the KVcache to the memory buffer after computing the prefill. The D instance then actively retrieves the computed KVcache from the P instance once it has allocated space for the KVcache. There are three methods for KVCache transfer: PUT, GET, and PUT_ASYNC. These methods can be specified using the `--kv-transfer-config` and `kv_connector_extra_config` parameters, specifically through the `send_type` field. Both PUT and PUT_ASYNC involve the P instance actively sending KVCache to the D instance. The difference is that PUT is a synchronous transfer method that blocks the main process, while PUT_ASYNC is an asynchronous transfer method. PUT_ASYNC uses a dedicated thread for sending KVCache, which means it does not block the main process. In contrast, the GET method involves the P instance saving the KVCache to the memory buffer after computing the prefill. The D instance then actively retrieves the computed KVCache from the P instance once it has allocated space for the KVCache.
Experimental results have shown that the performance of these methods, from highest to lowest, is as follows: PUT_ASYNC → GET → PUT. Experimental results have shown that the performance of these methods, from highest to lowest, is as follows: PUT_ASYNC → GET → PUT.
@ -39,13 +39,13 @@ Experimental results have shown that the performance of these methods, from high
As long as the address of the counterpart is known, point-to-point KV cache transfer (using NCCL) can be performed, without being constrained by rank and world size. To support dynamic scaling (expansion and contraction) of instances with PD disaggregation. This means that adding or removing P/D instances does not require a full system restart. As long as the address of the counterpart is known, point-to-point KV cache transfer (using NCCL) can be performed, without being constrained by rank and world size. To support dynamic scaling (expansion and contraction) of instances with PD disaggregation. This means that adding or removing P/D instances does not require a full system restart.
Each P/D instance only needs to create a single `P2pNcclEngine` instance. This instance maintains a ZMQ Server, which runs a dedicated thread to listen on the `zmq_addr` address and receive control flow requests from other instances. These requests include requests to establish an NCCL connection and requests to send KVcache metadata (such as tensor shapes and data types). However, it does not actually transmit the KVcache data itself. Each P/D instance only needs to create a single `P2pNcclEngine` instance. This instance maintains a ZMQ Server, which runs a dedicated thread to listen on the `zmq_addr` address and receive control flow requests from other instances. These requests include requests to establish an NCCL connection and requests to send KVCache metadata (such as tensor shapes and data types). However, it does not actually transmit the KVCache data itself.
When a P instance and a D instance transmit KVcache for the first time, they need to establish a ZMQ connection and an NCCL group. For subsequent KVcache transmissions, this ZMQ connection and NCCL group are reused. The NCCL group consists of only two ranks, meaning the world size is equal to 2. This design is intended to support dynamic scaling, which means that adding or removing P/D instances does not require a full system restart. As long as the address of the counterpart is known, point-to-point KVcache transmission can be performed, without being restricted by rank or world size. When a P instance and a D instance transmit KVCache for the first time, they need to establish a ZMQ connection and an NCCL group. For subsequent KVCache transmissions, this ZMQ connection and NCCL group are reused. The NCCL group consists of only two ranks, meaning the world size is equal to 2. This design is intended to support dynamic scaling, which means that adding or removing P/D instances does not require a full system restart. As long as the address of the counterpart is known, point-to-point KVCache transmission can be performed, without being restricted by rank or world size.
## NCCL Group Topology ## NCCL Group Topology
Currently, only symmetric TP (Tensor Parallelism) methods are supported for KVcache transmission. Asymmetric TP and PP (Pipeline Parallelism) methods will be supported in the future. Figure 2 illustrates the 1P2D setup, where each instance has a TP (Tensor Parallelism) degree of 2. There are a total of 7 NCCL groups: three vLLM instances each have one NCCL group with TP=2. Additionally, the 0th GPU card of the P instance establishes an NCCL group with the 0th GPU card of each D instance. Similarly, the 1st GPU card of the P instance establishes an NCCL group with the 1st GPU card of each D instance. Currently, only symmetric TP (Tensor Parallelism) methods are supported for KVCache transmission. Asymmetric TP and PP (Pipeline Parallelism) methods will be supported in the future. Figure 2 illustrates the 1P2D setup, where each instance has a TP (Tensor Parallelism) degree of 2. There are a total of 7 NCCL groups: three vLLM instances each have one NCCL group with TP=2. Additionally, the 0th GPU card of the P instance establishes an NCCL group with the 0th GPU card of each D instance. Similarly, the 1st GPU card of the P instance establishes an NCCL group with the 1st GPU card of each D instance.
![image2](https://github.com/user-attachments/assets/837e61d6-365e-4cbf-8640-6dd7ab295b36) ![image2](https://github.com/user-attachments/assets/837e61d6-365e-4cbf-8640-6dd7ab295b36)
@ -53,32 +53,18 @@ Each NCCL group occupies a certain amount of GPU memory buffer for communication
## GPU Memory Buffer and Tensor Memory Pool ## GPU Memory Buffer and Tensor Memory Pool
The trade-off in the size of the memory buffer is as follows: For P instances, the memory buffer is not required in PUT and PUT_ASYNC modes, but it is necessary in GET mode. For D instances, a memory buffer is needed in all three modes. The memory buffer for D instances should not be too large. Similarly, for P instances in GET mode, the memory buffer should also not be too large. The memory buffer of D instances is used to temporarily store KVcache sent by P instances. If it is too large, it will reduce the KVcache space available for normal inference by D instances, thereby decreasing the inference batch size and ultimately leading to a reduction in output throughput. The size of the memory buffer is configured by the parameter `kv_buffer_size`, measured in bytes, and is typically set to 5%10% of the memory size. The trade-off in the size of the memory buffer is as follows: For P instances, the memory buffer is not required in PUT and PUT_ASYNC modes, but it is necessary in GET mode. For D instances, a memory buffer is needed in all three modes. The memory buffer for D instances should not be too large. Similarly, for P instances in GET mode, the memory buffer should also not be too large. The memory buffer of D instances is used to temporarily store KVCache sent by P instances. If it is too large, it will reduce the KVCache space available for normal inference by D instances, thereby decreasing the inference batch size and ultimately leading to a reduction in output throughput. The size of the memory buffer is configured by the parameter `kv_buffer_size`, measured in bytes, and is typically set to 5%10% of the memory size.
If the `--max-num-seqs` parameter for P instances is set to a large value, due to the large batch size, P instances will generate a large amount of KVcache simultaneously. This may exceed the capacity of the memory buffer of D instances, resulting in KVcache loss. Once KVcache is lost, D instances need to recompute Prefill, which is equivalent to performing Prefill twice. Consequently, the time-to-first-token (TTFT) will significantly increase, leading to degraded performance. If the `--max-num-seqs` parameter for P instances is set to a large value, due to the large batch size, P instances will generate a large amount of KVCache simultaneously. This may exceed the capacity of the memory buffer of D instances, resulting in KVCache loss. Once KVCache is lost, D instances need to recompute Prefill, which is equivalent to performing Prefill twice. Consequently, the time-to-first-token (TTFT) will significantly increase, leading to degraded performance.
To address the above issues, I have designed and developed a local Tensor memory pool for storing KVcache, inspired by the buddy system used in Linux memory modules. Since the memory is sufficiently large, typically in the TB range on servers, there is no need to consider prefix caching or using block-based designs to reuse memory, thereby saving space. When the memory buffer is insufficient, KVcache can be directly stored in the Tensor memory pool, and D instances can subsequently retrieve KVcache from it. The read and write speed is that of PCIe, with PCIe 4.0 having a speed of approximately 21 GB/s, which is usually faster than the Prefill speed. Otherwise, solutions like Mooncake and lmcache would not be necessary. The Tensor memory pool acts as a flood diversion area, typically unused except during sudden traffic surges. In the worst-case scenario, my solution performs no worse than the normal situation with a Cache store. To address the above issues, I have designed and developed a local Tensor memory pool for storing KVCache, inspired by the buddy system used in Linux memory modules. Since the memory is sufficiently large, typically in the TB range on servers, there is no need to consider prefix caching or using block-based designs to reuse memory, thereby saving space. When the memory buffer is insufficient, KVCache can be directly stored in the Tensor memory pool, and D instances can subsequently retrieve KVCache from it. The read and write speed is that of PCIe, with PCIe 4.0 having a speed of approximately 21 GB/s, which is usually faster than the Prefill speed. Otherwise, solutions like Mooncake and lmcache would not be necessary. The Tensor memory pool acts as a flood diversion area, typically unused except during sudden traffic surges. In the worst-case scenario, my solution performs no worse than the normal situation with a Cache store.
# Install vLLM # Install vLLM
??? console "Commands" ??? console "Commands"
```shell ```shell
# Enter the home directory or your working directory. pip install "vllm>=0.9.2"
cd /home
# Download the installation package, and I will update the commit-id in time. You can directly copy the command.
wget https://vllm-wheels.s3.us-west-2.amazonaws.com/9112b443a042d8d815880b8780633882ad32b183/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl
# Download the code repository.
git clone -b xpyd-v1 https://github.com/Abatom/vllm.git
cd vllm
# Set the installation package path.
export VLLM_PRECOMPILED_WHEEL_LOCATION=/home/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl
# installation
pip install -e . -v
``` ```
# Run xPyD # Run xPyD
@ -90,7 +76,7 @@ To address the above issues, I have designed and developed a local Tensor memory
- You may need to modify the `kv_buffer_size` and `port` in the following commands (if there is a conflict). - You may need to modify the `kv_buffer_size` and `port` in the following commands (if there is a conflict).
- `PUT_ASYNC` offers the best performance and should be prioritized. - `PUT_ASYNC` offers the best performance and should be prioritized.
- The `--port` must be consistent with the `http_port` in the `--kv-transfer-config`. - The `--port` must be consistent with the `http_port` in the `--kv-transfer-config`.
- The `disagg_prefill_proxy_xpyd.py` script will use port 10001 (for receiving client requests) and port 30001 (for receiving service discovery from P and D instances). - The `disagg_proxy_p2p_nccl_xpyd.py` script will use port 10001 (for receiving client requests) and port 30001 (for receiving service discovery from P and D instances).
- The node running the proxy must have `quart` installed. - The node running the proxy must have `quart` installed.
- Supports multiple nodes; you just need to modify the `proxy_ip` and `proxy_port` in `--kv-transfer-config`. - Supports multiple nodes; you just need to modify the `proxy_ip` and `proxy_port` in `--kv-transfer-config`.
- In the following examples, it is assumed that **the proxy's IP is 10.0.1.1**. - In the following examples, it is assumed that **the proxy's IP is 10.0.1.1**.
@ -100,8 +86,8 @@ To address the above issues, I have designed and developed a local Tensor memory
### Proxy (e.g. 10.0.1.1) ### Proxy (e.g. 10.0.1.1)
```shell ```shell
cd {your vllm directory}/examples/online_serving/disagg_xpyd/ cd {your vllm directory}/examples/online_serving/disaggregated_serving_p2p_nccl_xpyd/
python3 disagg_prefill_proxy_xpyd.py & python3 disagg_proxy_p2p_nccl_xpyd.py &
``` ```
### Prefill1 (e.g. 10.0.1.2 or 10.0.1.1) ### Prefill1 (e.g. 10.0.1.2 or 10.0.1.1)
@ -111,7 +97,7 @@ python3 disagg_prefill_proxy_xpyd.py &
```shell ```shell
VLLM_USE_V1=1 CUDA_VISIBLE_DEVICES=0 vllm serve {your model directory} \ VLLM_USE_V1=1 CUDA_VISIBLE_DEVICES=0 vllm serve {your model directory} \
--host 0.0.0.0 \ --host 0.0.0.0 \
--port 20005 \ --port 20001 \
--tensor-parallel-size 1 \ --tensor-parallel-size 1 \
--seed 1024 \ --seed 1024 \
--served-model-name base_model \ --served-model-name base_model \
@ -123,7 +109,7 @@ python3 disagg_prefill_proxy_xpyd.py &
--gpu-memory-utilization 0.9 \ --gpu-memory-utilization 0.9 \
--disable-log-request \ --disable-log-request \
--kv-transfer-config \ --kv-transfer-config \
'{"kv_connector":"P2pNcclConnector","kv_role":"kv_producer","kv_buffer_size":"1e1","kv_port":"21001","kv_connector_extra_config":{"proxy_ip":"10.0.1.1","proxy_port":"30001","http_port":"20005","send_type":"PUT_ASYNC","nccl_num_channels":"16"}}' > /var/vllm.log 2>&1 & '{"kv_connector":"P2pNcclConnector","kv_role":"kv_producer","kv_buffer_size":"1e1","kv_port":"21001","kv_connector_extra_config":{"proxy_ip":"10.0.1.1","proxy_port":"30001","http_port":"20001"}}' > /var/vllm.log 2>&1 &
``` ```
### Decode1 (e.g. 10.0.1.3 or 10.0.1.1) ### Decode1 (e.g. 10.0.1.3 or 10.0.1.1)
@ -133,7 +119,7 @@ python3 disagg_prefill_proxy_xpyd.py &
```shell ```shell
VLLM_USE_V1=1 CUDA_VISIBLE_DEVICES=1 vllm serve {your model directory} \ VLLM_USE_V1=1 CUDA_VISIBLE_DEVICES=1 vllm serve {your model directory} \
--host 0.0.0.0 \ --host 0.0.0.0 \
--port 20009 \ --port 20002 \
--tensor-parallel-size 1 \ --tensor-parallel-size 1 \
--seed 1024 \ --seed 1024 \
--served-model-name base_model \ --served-model-name base_model \
@ -145,7 +131,7 @@ python3 disagg_prefill_proxy_xpyd.py &
--gpu-memory-utilization 0.7 \ --gpu-memory-utilization 0.7 \
--disable-log-request \ --disable-log-request \
--kv-transfer-config \ --kv-transfer-config \
'{"kv_connector":"P2pNcclConnector","kv_role":"kv_consumer","kv_buffer_size":"8e9","kv_port":"22001","kv_connector_extra_config":{"proxy_ip":"10.0.1.1","proxy_port":"30001","http_port":"20009","send_type":"PUT_ASYNC","nccl_num_channels":"16"}}' > /var/vllm.log 2>&1 & '{"kv_connector":"P2pNcclConnector","kv_role":"kv_consumer","kv_buffer_size":"8e9","kv_port":"22001","kv_connector_extra_config":{"proxy_ip":"10.0.1.1","proxy_port":"30001","http_port":"20002"}}' > /var/vllm.log 2>&1 &
``` ```
### Decode2 (e.g. 10.0.1.4 or 10.0.1.1) ### Decode2 (e.g. 10.0.1.4 or 10.0.1.1)
@ -167,7 +153,7 @@ python3 disagg_prefill_proxy_xpyd.py &
--gpu-memory-utilization 0.7 \ --gpu-memory-utilization 0.7 \
--disable-log-request \ --disable-log-request \
--kv-transfer-config \ --kv-transfer-config \
'{"kv_connector":"P2pNcclConnector","kv_role":"kv_consumer","kv_buffer_size":"8e9","kv_port":"23001","kv_connector_extra_config":{"proxy_ip":"10.0.1.1","proxy_port":"30001","http_port":"20003","send_type":"PUT_ASYNC","nccl_num_channels":"16"}}' > /var/vllm.log 2>&1 & '{"kv_connector":"P2pNcclConnector","kv_role":"kv_consumer","kv_buffer_size":"8e9","kv_port":"23001","kv_connector_extra_config":{"proxy_ip":"10.0.1.1","proxy_port":"30001","http_port":"20003"}}' > /var/vllm.log 2>&1 &
``` ```
### Decode3 (e.g. 10.0.1.5 or 10.0.1.1) ### Decode3 (e.g. 10.0.1.5 or 10.0.1.1)
@ -177,7 +163,7 @@ python3 disagg_prefill_proxy_xpyd.py &
```shell ```shell
VLLM_USE_V1=1 CUDA_VISIBLE_DEVICES=3 vllm serve {your model directory} \ VLLM_USE_V1=1 CUDA_VISIBLE_DEVICES=3 vllm serve {your model directory} \
--host 0.0.0.0 \ --host 0.0.0.0 \
--port 20008 \ --port 20004 \
--tensor-parallel-size 1 \ --tensor-parallel-size 1 \
--seed 1024 \ --seed 1024 \
--served-model-name base_model \ --served-model-name base_model \
@ -189,7 +175,7 @@ python3 disagg_prefill_proxy_xpyd.py &
--gpu-memory-utilization 0.7 \ --gpu-memory-utilization 0.7 \
--disable-log-request \ --disable-log-request \
--kv-transfer-config \ --kv-transfer-config \
'{"kv_connector":"P2pNcclConnector","kv_role":"kv_consumer","kv_buffer_size":"8e9","kv_port":"24001","kv_connector_extra_config":{"proxy_ip":"10.0.1.1","proxy_port":"30001","http_port":"20008","send_type":"PUT_ASYNC","nccl_num_channels":"16"}}' > /var/vllm.log 2>&1 & '{"kv_connector":"P2pNcclConnector","kv_role":"kv_consumer","kv_buffer_size":"8e9","kv_port":"24001","kv_connector_extra_config":{"proxy_ip":"10.0.1.1","proxy_port":"30001","http_port":"20004"}}' > /var/vllm.log 2>&1 &
``` ```
## Run 3P1D ## Run 3P1D
@ -197,8 +183,8 @@ python3 disagg_prefill_proxy_xpyd.py &
### Proxy (e.g. 10.0.1.1) ### Proxy (e.g. 10.0.1.1)
```shell ```shell
cd {your vllm directory}/examples/online_serving/disagg_xpyd/ cd {your vllm directory}/examples/online_serving/disaggregated_serving_p2p_nccl_xpyd/
python3 disagg_prefill_proxy_xpyd.py & python3 disagg_proxy_p2p_nccl_xpyd.py &
``` ```
### Prefill1 (e.g. 10.0.1.2 or 10.0.1.1) ### Prefill1 (e.g. 10.0.1.2 or 10.0.1.1)
@ -208,7 +194,7 @@ python3 disagg_prefill_proxy_xpyd.py &
```shell ```shell
VLLM_USE_V1=1 CUDA_VISIBLE_DEVICES=0 vllm serve {your model directory} \ VLLM_USE_V1=1 CUDA_VISIBLE_DEVICES=0 vllm serve {your model directory} \
--host 0.0.0.0 \ --host 0.0.0.0 \
--port 20005 \ --port 20001 \
--tensor-parallel-size 1 \ --tensor-parallel-size 1 \
--seed 1024 \ --seed 1024 \
--served-model-name base_model \ --served-model-name base_model \
@ -220,7 +206,7 @@ python3 disagg_prefill_proxy_xpyd.py &
--gpu-memory-utilization 0.9 \ --gpu-memory-utilization 0.9 \
--disable-log-request \ --disable-log-request \
--kv-transfer-config \ --kv-transfer-config \
'{"kv_connector":"P2pNcclConnector","kv_role":"kv_producer","kv_buffer_size":"1e1","kv_port":"21001","kv_connector_extra_config":{"proxy_ip":"10.0.1.1","proxy_port":"30001","http_port":"20005","send_type":"PUT_ASYNC","nccl_num_channels":"16"}}' > /var/vllm.log 2>&1 & '{"kv_connector":"P2pNcclConnector","kv_role":"kv_producer","kv_buffer_size":"1e1","kv_port":"21001","kv_connector_extra_config":{"proxy_ip":"10.0.1.1","proxy_port":"30001","http_port":"20001"}}' > /var/vllm.log 2>&1 &
``` ```
### Prefill2 (e.g. 10.0.1.3 or 10.0.1.1) ### Prefill2 (e.g. 10.0.1.3 or 10.0.1.1)
@ -230,7 +216,7 @@ python3 disagg_prefill_proxy_xpyd.py &
```shell ```shell
VLLM_USE_V1=1 CUDA_VISIBLE_DEVICES=1 vllm serve {your model directory} \ VLLM_USE_V1=1 CUDA_VISIBLE_DEVICES=1 vllm serve {your model directory} \
--host 0.0.0.0 \ --host 0.0.0.0 \
--port 20009 \ --port 20002 \
--tensor-parallel-size 1 \ --tensor-parallel-size 1 \
--seed 1024 \ --seed 1024 \
--served-model-name base_model \ --served-model-name base_model \
@ -242,7 +228,7 @@ python3 disagg_prefill_proxy_xpyd.py &
--gpu-memory-utilization 0.9 \ --gpu-memory-utilization 0.9 \
--disable-log-request \ --disable-log-request \
--kv-transfer-config \ --kv-transfer-config \
'{"kv_connector":"P2pNcclConnector","kv_role":"kv_producer","kv_buffer_size":"1e1","kv_port":"22001","kv_connector_extra_config":{"proxy_ip":"10.0.1.1","proxy_port":"30001","http_port":"20009","send_type":"PUT_ASYNC","nccl_num_channels":"16"}}' > /var/vllm.log 2>&1 & '{"kv_connector":"P2pNcclConnector","kv_role":"kv_producer","kv_buffer_size":"1e1","kv_port":"22001","kv_connector_extra_config":{"proxy_ip":"10.0.1.1","proxy_port":"30001","http_port":"20002"}}' > /var/vllm.log 2>&1 &
``` ```
### Prefill3 (e.g. 10.0.1.4 or 10.0.1.1) ### Prefill3 (e.g. 10.0.1.4 or 10.0.1.1)
@ -264,7 +250,7 @@ python3 disagg_prefill_proxy_xpyd.py &
--gpu-memory-utilization 0.9 \ --gpu-memory-utilization 0.9 \
--disable-log-request \ --disable-log-request \
--kv-transfer-config \ --kv-transfer-config \
'{"kv_connector":"P2pNcclConnector","kv_role":"kv_producer","kv_buffer_size":"1e1","kv_port":"23001","kv_connector_extra_config":{"proxy_ip":"10.0.1.1","proxy_port":"30001","http_port":"20003","send_type":"PUT_ASYNC","nccl_num_channels":"16"}}' > /var/vllm.log 2>&1 & '{"kv_connector":"P2pNcclConnector","kv_role":"kv_producer","kv_buffer_size":"1e1","kv_port":"23001","kv_connector_extra_config":{"proxy_ip":"10.0.1.1","proxy_port":"30001","http_port":"20003"}}' > /var/vllm.log 2>&1 &
``` ```
### Decode1 (e.g. 10.0.1.5 or 10.0.1.1) ### Decode1 (e.g. 10.0.1.5 or 10.0.1.1)
@ -274,7 +260,7 @@ python3 disagg_prefill_proxy_xpyd.py &
```shell ```shell
VLLM_USE_V1=1 CUDA_VISIBLE_DEVICES=3 vllm serve {your model directory} \ VLLM_USE_V1=1 CUDA_VISIBLE_DEVICES=3 vllm serve {your model directory} \
--host 0.0.0.0 \ --host 0.0.0.0 \
--port 20008 \ --port 20004 \
--tensor-parallel-size 1 \ --tensor-parallel-size 1 \
--seed 1024 \ --seed 1024 \
--served-model-name base_model \ --served-model-name base_model \
@ -286,7 +272,7 @@ python3 disagg_prefill_proxy_xpyd.py &
--gpu-memory-utilization 0.7 \ --gpu-memory-utilization 0.7 \
--disable-log-request \ --disable-log-request \
--kv-transfer-config \ --kv-transfer-config \
'{"kv_connector":"P2pNcclConnector","kv_role":"kv_consumer","kv_buffer_size":"8e9","kv_port":"24001","kv_connector_extra_config":{"proxy_ip":"10.0.1.1","proxy_port":"30001","http_port":"20008","send_type":"PUT_ASYNC","nccl_num_channels":"16"}}' > /var/vllm.log 2>&1 & '{"kv_connector":"P2pNcclConnector","kv_role":"kv_consumer","kv_buffer_size":"8e9","kv_port":"24001","kv_connector_extra_config":{"proxy_ip":"10.0.1.1","proxy_port":"30001","http_port":"20004"}}' > /var/vllm.log 2>&1 &
``` ```
# Single request # Single request
@ -334,24 +320,6 @@ pgrep python | xargs kill -9 && pkill -f python
# Test data # Test data
## **Scenario 1**: 1K input & 1K output tokens, E2E P99 latency ~20s ## **Scenario**: 1K input & 200 output tokens, E2E P99 latency ~2s
- **1P5D (6×A800) vs vLLM (1×A800)**:
- Throughput ↑7.2% (1085 → 6979/6)
- ITL (P99) ↓81.3% (120ms → 22.9ms)
- TTFT (P99) ↑26.8% (175ms → 222ms)
- TPOT: No change
- **1P6D (7×A800) vs vLLM (1×A800)**: ![testdata](https://github.com/user-attachments/assets/cef0953b-4567-4bf9-b940-405b92a28eb1)
- Throughput ↑9.6% (1085 → 8329/7)
- ITL (P99) ↓81.0% (120ms → 22.7ms)
- TTFT (P99) ↑210% (175ms →543ms)
- TPOT: No change
## **Scenario 2**: 1K input & 200 output tokens, E2E P99 latency ~4s
- **1P1D (2×A800) vs vLLM (1×A800)**:
- Throughput ↑37.4% (537 → 1476/2)
- ITL (P99) ↓81.8% (127ms → 23.1ms)
- TTFT (P99) ↑41.8% (160ms → 227ms)
- TPOT: No change
![testdata](https://github.com/user-attachments/assets/f791bfc7-9f3d-4e5c-9171-a42f9f4da627)

View File

@ -4,7 +4,9 @@
import os import os
import socket import socket
import threading import threading
import time
import uuid import uuid
from typing import Any
import aiohttp import aiohttp
import msgpack import msgpack
@ -12,12 +14,25 @@ import zmq
from quart import Quart, make_response, request from quart import Quart, make_response, request
count = 0 count = 0
prefill_instances: dict[str, str] = {} # http_address: zmq_address prefill_instances: dict[str, Any] = {} # http_address: (zmq_address, stamp)
decode_instances: dict[str, str] = {} # http_address: zmq_address decode_instances: dict[str, Any] = {} # http_address: (zmq_address, stamp)
prefill_cv = threading.Condition() prefill_cv = threading.Condition()
decode_cv = threading.Condition() decode_cv = threading.Condition()
DEFAULT_PING_SECONDS = 5
def _remove_oldest_instances(instances: dict[str, Any]) -> None:
oldest_key = next(iter(instances), None)
while oldest_key is not None:
value = instances[oldest_key]
if value[1] > time.time():
break
print(f"🔴Remove [HTTP:{oldest_key}, ZMQ:{value[0]}, stamp:{value[1]}]")
instances.pop(oldest_key, None)
oldest_key = next(iter(instances), None)
def _listen_for_register(poller, router_socket): def _listen_for_register(poller, router_socket):
while True: while True:
@ -31,12 +46,23 @@ def _listen_for_register(poller, router_socket):
global prefill_instances global prefill_instances
global prefill_cv global prefill_cv
with prefill_cv: with prefill_cv:
prefill_instances[data["http_address"]] = data["zmq_address"] node = prefill_instances.pop(data["http_address"], None)
prefill_instances[data["http_address"]] = (
data["zmq_address"],
time.time() + DEFAULT_PING_SECONDS,
)
_remove_oldest_instances(prefill_instances)
elif data["type"] == "D": elif data["type"] == "D":
global decode_instances global decode_instances
global decode_cv global decode_cv
with decode_cv: with decode_cv:
decode_instances[data["http_address"]] = data["zmq_address"] node = decode_instances.pop(data["http_address"], None)
decode_instances[data["http_address"]] = (
data["zmq_address"],
time.time() + DEFAULT_PING_SECONDS,
)
_remove_oldest_instances(decode_instances)
else: else:
print( print(
"Unexpected, Received message from %s, data: %s", "Unexpected, Received message from %s, data: %s",
@ -44,6 +70,9 @@ def _listen_for_register(poller, router_socket):
data, data,
) )
if node is None:
print(f"🔵Add [HTTP:{data['http_address']}, ZMQ:{data['zmq_address']}]")
def start_service_discovery(hostname, port): def start_service_discovery(hostname, port):
if not hostname: if not hostname:
@ -105,12 +134,14 @@ async def handle_request():
with prefill_cv: with prefill_cv:
prefill_list = list(prefill_instances.items()) prefill_list = list(prefill_instances.items())
prefill_addr, prefill_zmq_addr = prefill_list[count % len(prefill_list)] prefill_addr, prefill_zmq_addr = prefill_list[count % len(prefill_list)]
prefill_zmq_addr = prefill_zmq_addr[0]
global decode_instances global decode_instances
global decode_cv global decode_cv
with decode_cv: with decode_cv:
decode_list = list(decode_instances.items()) decode_list = list(decode_instances.items())
decode_addr, decode_zmq_addr = decode_list[count % len(decode_list)] decode_addr, decode_zmq_addr = decode_list[count % len(decode_list)]
decode_zmq_addr = decode_zmq_addr[0]
print( print(
f"handle_request count: {count}, [HTTP:{prefill_addr}, " f"handle_request count: {count}, [HTTP:{prefill_addr}, "

View File

@ -13,7 +13,6 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import (
from vllm.distributed.kv_transfer.kv_connector.v1.p2p.p2p_nccl_engine import ( from vllm.distributed.kv_transfer.kv_connector.v1.p2p.p2p_nccl_engine import (
P2pNcclEngine) P2pNcclEngine)
from vllm.distributed.parallel_state import get_world_group from vllm.distributed.parallel_state import get_world_group
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.v1.attention.backends.mla.common import MLACommonMetadata from vllm.v1.attention.backends.mla.common import MLACommonMetadata
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
@ -238,32 +237,16 @@ class P2pNcclConnector(KVConnectorBase_V1):
assert self.p2p_nccl_engine is not None assert self.p2p_nccl_engine is not None
def extract_kv_from_layer(
layer: torch.Tensor,
slot_mapping: torch.Tensor,
) -> torch.Tensor:
"""Extract the KV cache from the layer.
Assume the shape of the layer is (2, num_pages, page_size, xxx)
if MLA is not used, and (num_pages, page_size, xxx) otherwise.
"""
if isinstance(attn_metadata, MLACommonMetadata):
num_pages, page_size = layer.shape[0], layer.shape[1]
return layer.reshape(num_pages * page_size, -1)[slot_mapping,
...]
num_pages, page_size = layer.shape[1], layer.shape[2]
return layer.reshape(2, num_pages * page_size, -1)[:, slot_mapping,
...]
connector_metadata = self._get_connector_metadata() connector_metadata = self._get_connector_metadata()
assert isinstance(connector_metadata, P2pNcclConnectorMetadata) assert isinstance(connector_metadata, P2pNcclConnectorMetadata)
for request in connector_metadata.requests: for request in connector_metadata.requests:
request_id = request.request_id request_id = request.request_id
ip, port = self.parse_request_id(request_id, True) ip, port = self.parse_request_id(request_id, True)
remote_address = ip + ":" + str(port + self._rank) remote_address = ip + ":" + str(port + self._rank)
kv_cache = extract_kv_from_layer(kv_layer, request.slot_mapping) self.p2p_nccl_engine.send_tensor(
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name, request_id + "#" + layer_name, kv_layer, remote_address,
kv_cache, remote_address) request.slot_mapping,
isinstance(attn_metadata, MLACommonMetadata))
def wait_for_save(self): def wait_for_save(self):
if self.is_producer: if self.is_producer:
@ -286,9 +269,10 @@ class P2pNcclConnector(KVConnectorBase_V1):
assert self.p2p_nccl_engine is not None assert self.p2p_nccl_engine is not None
forward_context: ForwardContext = get_forward_context() no_compile_layers = (
self._vllm_config.compilation_config.static_forward_context)
return self.p2p_nccl_engine.get_finished(finished_req_ids, return self.p2p_nccl_engine.get_finished(finished_req_ids,
forward_context) no_compile_layers)
# ============================== # ==============================
# Scheduler-side methods # Scheduler-side methods
@ -418,14 +402,6 @@ class P2pNcclConnector(KVConnectorBase_V1):
block_ids=block_ids, block_ids=block_ids,
block_size=self._block_size) block_size=self._block_size)
# Requests loaded asynchronously are not in the scheduler_output.
# for request_id in self._requests_need_load:
# request, block_ids = self._requests_need_load[request_id]
# meta.add_request(request_id=request.request_id,
# token_ids=request.prompt_token_ids,
# block_ids=block_ids,
# block_size=self._block_size)
self._requests_need_load.clear() self._requests_need_load.clear()
return meta return meta

View File

@ -8,7 +8,8 @@ import time
import typing import typing
from collections import deque from collections import deque
from contextlib import contextmanager from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Optional from dataclasses import dataclass
from typing import Any, Optional
import msgpack import msgpack
import torch import torch
@ -21,9 +22,6 @@ from vllm.distributed.kv_transfer.kv_connector.v1.p2p.tensor_memory_pool import
TensorMemoryPool) TensorMemoryPool)
from vllm.utils import current_stream, get_ip from vllm.utils import current_stream, get_ip
if TYPE_CHECKING:
from vllm.forward_context import ForwardContext
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
DEFAULT_MEM_POOL_SIZE_GB = 32 DEFAULT_MEM_POOL_SIZE_GB = 32
@ -59,6 +57,15 @@ def set_p2p_nccl_context(num_channels: str):
os.environ.pop(var, None) os.environ.pop(var, None)
@dataclass
class SendQueueItem:
tensor_id: str
remote_address: str
tensor: torch.Tensor
slot_mapping: torch.Tensor
is_mla: bool
class P2pNcclEngine: class P2pNcclEngine:
def __init__(self, def __init__(self,
@ -112,24 +119,26 @@ class P2pNcclEngine:
self.send_stream = torch.cuda.Stream() self.send_stream = torch.cuda.Stream()
self.recv_stream = torch.cuda.Stream() self.recv_stream = torch.cuda.Stream()
mem_pool_size_gb = self.config.get_from_extra_config( mem_pool_size_gb = float(
"mem_pool_size_gb", DEFAULT_MEM_POOL_SIZE_GB) self.config.get_from_extra_config("mem_pool_size_gb",
self.pool = TensorMemoryPool(max_block_size=int(mem_pool_size_gb) * DEFAULT_MEM_POOL_SIZE_GB))
1024**3) # GB self.pool = TensorMemoryPool(max_block_size=int(mem_pool_size_gb *
1024**3)) # GB
# The sending type includes tree mutually exclusive options: # The sending type includes tree mutually exclusive options:
# PUT, GET, PUT_ASYNC. # PUT, GET, PUT_ASYNC.
self.send_type = self.config.get_from_extra_config("send_type", "PUT") self.send_type = self.config.get_from_extra_config(
"send_type", "PUT_ASYNC")
if self.send_type == "GET": if self.send_type == "GET":
# tensor_id: torch.Tensor # tensor_id: torch.Tensor
self.send_store: dict[str, torch.Tensor] = {} self.send_store: dict[str, torch.Tensor] = {}
else: else:
# PUT or PUT_ASYNC # PUT or PUT_ASYNC
# tensor_id: torch.Tensor # tensor_id: torch.Tensor
self.send_queue: deque[list[Any]] = deque() self.send_queue: deque[SendQueueItem] = deque()
self.send_request_id_to_tensor_ids: dict[str, set[str]] = {} self.send_request_id_to_tensor_ids: dict[str, set[str]] = {}
if self.send_type == "PUT_ASYNC": if self.send_type == "PUT_ASYNC":
self._send_thread = threading.Thread(target=self._send_async, self._send_thread = threading.Thread(target=self.send_async,
daemon=True) daemon=True)
self._send_thread.start() self._send_thread.start()
@ -146,13 +155,12 @@ class P2pNcclEngine:
"nccl_num_channels", "8") "nccl_num_channels", "8")
self._listener_thread = threading.Thread( self._listener_thread = threading.Thread(
target=self._listen_for_requests, daemon=True) target=self.listen_for_requests, daemon=True)
self._listener_thread.start() self._listener_thread.start()
self._ping_thread = None self._ping_thread = None
if port_offset == 0 and self.proxy_address != "": if port_offset == 0 and self.proxy_address != "":
self._ping_thread = threading.Thread(target=self._ping, self._ping_thread = threading.Thread(target=self.ping, daemon=True)
daemon=True)
self._ping_thread.start() self._ping_thread.start()
logger.info( logger.info(
@ -162,7 +170,7 @@ class P2pNcclEngine:
self.http_address, self.zmq_address, self.proxy_address, self.http_address, self.zmq_address, self.proxy_address,
self.send_type, self.buffer_size_threshold, self.nccl_num_channels) self.send_type, self.buffer_size_threshold, self.nccl_num_channels)
def _create_connect(self, remote_address: typing.Optional[str] = None): def create_connect(self, remote_address: typing.Optional[str] = None):
assert remote_address is not None assert remote_address is not None
if remote_address not in self.socks: if remote_address not in self.socks:
sock = self.context.socket(zmq.DEALER) sock = self.context.socket(zmq.DEALER)
@ -184,7 +192,7 @@ class P2pNcclEngine:
comm: ncclComm_t = self.nccl.ncclCommInitRank( comm: ncclComm_t = self.nccl.ncclCommInitRank(
2, unique_id, rank) 2, unique_id, rank)
self.comms[remote_address] = (comm, rank) self.comms[remote_address] = (comm, rank)
logger.info("🤝ncclCommInitRank Success, %s👉%s, MyRank: %s", logger.info("🤝ncclCommInitRank Success, %s👉%s, MyRank:%s",
self.zmq_address, remote_address, rank) self.zmq_address, remote_address, rank)
return self.socks[remote_address], self.comms[remote_address] return self.socks[remote_address], self.comms[remote_address]
@ -194,44 +202,54 @@ class P2pNcclEngine:
tensor_id: str, tensor_id: str,
tensor: torch.Tensor, tensor: torch.Tensor,
remote_address: typing.Optional[str] = None, remote_address: typing.Optional[str] = None,
slot_mapping: torch.Tensor = None,
is_mla: bool = False,
) -> bool: ) -> bool:
if remote_address is None: if remote_address is None:
with self.recv_store_cv: with self.recv_store_cv:
self.recv_store[tensor_id] = tensor self.recv_store[tensor_id] = tensor
self.recv_store_cv.notify() self.recv_store_cv.notify()
return True return True
else:
if self.send_type == "PUT":
return self._send_sync(tensor_id, tensor, remote_address)
elif self.send_type == "PUT_ASYNC":
with self.send_queue_cv:
self.send_queue.append([tensor_id, remote_address, tensor])
self.send_queue_cv.notify()
else: # GET
with self.send_store_cv:
tensor_size = tensor.element_size() * tensor.numel()
while (self.buffer_size + tensor_size
> self.buffer_size_threshold):
oldest_tenser_id = next(iter(self.send_store))
oldest_tenser = self.send_store.pop(oldest_tenser_id)
oldest_tenser_size = oldest_tenser.element_size(
) * oldest_tenser.numel()
self.buffer_size -= oldest_tenser_size
logger.info(
"⛔[GET]Send to %s, tensor_id:%s, tensor_size:%d,"
" buffer_size:%d, oldest_tenser_size:%d, rank:%d",
remote_address, tensor_id, tensor_size,
self.buffer_size, oldest_tenser_size, self.rank)
self.send_store[tensor_id] = tensor item = SendQueueItem(tensor_id=tensor_id,
self.buffer_size += tensor_size remote_address=remote_address,
logger.debug( tensor=tensor,
"🔵[GET]Send to %s, tensor_id:%s, tensor_size:%d, " slot_mapping=slot_mapping,
"shape:%s, rank:%d, buffer_size:%d(%.2f%%)", is_mla=is_mla)
remote_address, tensor_id, tensor_size, tensor.shape,
self.rank, self.buffer_size,
self.buffer_size / self.buffer_size_threshold * 100)
if self.send_type == "PUT":
return self.send_sync(item)
if self.send_type == "PUT_ASYNC":
with self.send_queue_cv:
self.send_queue.append(item)
self.send_queue_cv.notify()
return True
# GET
with self.send_store_cv:
tensor_size = tensor.element_size() * tensor.numel()
while (self.buffer_size + tensor_size
> self.buffer_size_threshold):
oldest_tenser_id = next(iter(self.send_store))
oldest_tenser = self.send_store.pop(oldest_tenser_id)
oldest_tenser_size = oldest_tenser.element_size(
) * oldest_tenser.numel()
self.buffer_size -= oldest_tenser_size
logger.info(
"⛔[GET]Send to %s, tensor_id:%s, tensor_size:%d,"
" buffer_size:%d, oldest_tenser_size:%d, rank:%d",
remote_address, tensor_id, tensor_size, self.buffer_size,
oldest_tenser_size, self.rank)
self.send_store[tensor_id] = tensor
self.buffer_size += tensor_size
logger.debug(
"🔵[GET]Send to %s, tensor_id:%s, tensor_size:%d, "
"shape:%s, rank:%d, buffer_size:%d(%.2f%%)", remote_address,
tensor_id, tensor_size, tensor.shape, self.rank,
self.buffer_size,
self.buffer_size / self.buffer_size_threshold * 100)
return True return True
def recv_tensor( def recv_tensor(
@ -267,7 +285,7 @@ class P2pNcclEngine:
return None return None
if remote_address not in self.socks: if remote_address not in self.socks:
self._create_connect(remote_address) self.create_connect(remote_address)
sock = self.socks[remote_address] sock = self.socks[remote_address]
comm, rank = self.comms[remote_address] comm, rank = self.comms[remote_address]
@ -282,121 +300,121 @@ class P2pNcclEngine:
remote_address, tensor_id, data["ret"]) remote_address, tensor_id, data["ret"])
return None return None
tensor = torch.empty(data["shape"], with torch.cuda.stream(self.recv_stream):
dtype=getattr(torch, data["dtype"]), tensor = torch.empty(data["shape"],
device=self.device) dtype=getattr(torch, data["dtype"]),
device=self.device)
self._recv(comm, tensor, rank ^ 1, self.recv_stream) self.recv(comm, tensor, rank ^ 1, self.recv_stream)
return tensor return tensor
def _listen_for_requests(self): def listen_for_requests(self):
while True: while True:
socks = dict(self.poller.poll()) socks = dict(self.poller.poll())
if self.router_socket in socks: if self.router_socket not in socks:
remote_address, message = self.router_socket.recv_multipart() continue
data = msgpack.loads(message)
if data["cmd"] == "NEW":
unique_id = self.nccl.unique_id_from_bytes(
bytes(data["unique_id"]))
with torch.cuda.device(self.device):
rank = 1
with set_p2p_nccl_context(self.nccl_num_channels):
comm: ncclComm_t = self.nccl.ncclCommInitRank(
2, unique_id, rank)
self.comms[remote_address.decode()] = (comm, rank)
logger.info(
"🤝ncclCommInitRank Success, %s👈%s, MyRank:%s",
self.zmq_address, remote_address.decode(), rank)
elif data["cmd"] == "PUT":
tensor_id = data["tensor_id"]
try:
with torch.cuda.stream(self.recv_stream):
tensor = torch.empty(data["shape"],
dtype=getattr(
torch, data["dtype"]),
device=self.device)
self.router_socket.send_multipart(
[remote_address, b"0"])
comm, rank = self.comms[remote_address.decode()]
self._recv(comm, tensor, rank ^ 1, self.recv_stream)
tensor_size = tensor.element_size() * tensor.numel()
if (self.buffer_size + tensor_size
> self.buffer_size_threshold):
# Store Tensor in memory pool
addr = self.pool.store_tensor(tensor)
tensor = (addr, tensor.dtype, tensor.shape)
logger.warning(
"🔴[PUT]Recv Tensor, Out Of Threshold, "
"%s👈%s, data:%s, addr:%d", self.zmq_address,
remote_address.decode(), data, addr)
else:
self.buffer_size += tensor_size
except torch.cuda.OutOfMemoryError: remote_address, message = self.router_socket.recv_multipart()
self.router_socket.send_multipart( data = msgpack.loads(message)
[remote_address, b"1"]) if data["cmd"] == "NEW":
tensor = None unique_id = self.nccl.unique_id_from_bytes(
bytes(data["unique_id"]))
with torch.cuda.device(self.device):
rank = 1
with set_p2p_nccl_context(self.nccl_num_channels):
comm: ncclComm_t = self.nccl.ncclCommInitRank(
2, unique_id, rank)
self.comms[remote_address.decode()] = (comm, rank)
logger.info("🤝ncclCommInitRank Success, %s👈%s, MyRank:%s",
self.zmq_address, remote_address.decode(),
rank)
elif data["cmd"] == "PUT":
tensor_id = data["tensor_id"]
try:
with torch.cuda.stream(self.recv_stream):
tensor = torch.empty(data["shape"],
dtype=getattr(
torch, data["dtype"]),
device=self.device)
self.router_socket.send_multipart([remote_address, b"0"])
comm, rank = self.comms[remote_address.decode()]
self.recv(comm, tensor, rank ^ 1, self.recv_stream)
tensor_size = tensor.element_size() * tensor.numel()
if (self.buffer_size + tensor_size
> self.buffer_size_threshold):
# Store Tensor in memory pool
addr = self.pool.store_tensor(tensor)
tensor = (addr, tensor.dtype, tensor.shape)
logger.warning( logger.warning(
"🔴[PUT]Recv Tensor, Out Of Memory, %s👈%s, " "🔴[PUT]Recv Tensor, Out Of Threshold, "
"data:%s", self.zmq_address, "%s👈%s, data:%s, addr:%d", self.zmq_address,
remote_address.decode(), data) remote_address.decode(), data, addr)
else:
self.buffer_size += tensor_size
with self.recv_store_cv: except torch.cuda.OutOfMemoryError:
self.recv_store[tensor_id] = tensor self.router_socket.send_multipart([remote_address, b"1"])
self._have_received_tensor_id(tensor_id) tensor = None
self.recv_store_cv.notify()
elif data["cmd"] == "GET":
tensor_id = data["tensor_id"]
with self.send_store_cv:
tensor = self.send_store.pop(tensor_id, None)
if tensor is not None:
data = {
"ret": 0,
"shape": tensor.shape,
"dtype":
str(tensor.dtype).replace("torch.", "")
}
# LRU
self.send_store[tensor_id] = tensor
self._have_sent_tensor_id(tensor_id)
else:
data = {"ret": 1}
self.router_socket.send_multipart(
[remote_address, msgpack.dumps(data)])
if data["ret"] == 0:
comm, rank = self.comms[remote_address.decode()]
self._send(comm, tensor.to(self.device), rank ^ 1,
self.send_stream)
else:
logger.warning( logger.warning(
"🚧Unexpected, Received message from %s, data:%s", "🔴[PUT]Recv Tensor, Out Of Memory, %s👈%s, "
remote_address, data) "data:%s", self.zmq_address, remote_address.decode(),
data)
def _have_sent_tensor_id(self, tensor_id: str): with self.recv_store_cv:
self.recv_store[tensor_id] = tensor
self.have_received_tensor_id(tensor_id)
self.recv_store_cv.notify()
elif data["cmd"] == "GET":
tensor_id = data["tensor_id"]
with self.send_store_cv:
tensor = self.send_store.pop(tensor_id, None)
if tensor is not None:
data = {
"ret": 0,
"shape": tensor.shape,
"dtype": str(tensor.dtype).replace("torch.", "")
}
# LRU
self.send_store[tensor_id] = tensor
self.have_sent_tensor_id(tensor_id)
else:
data = {"ret": 1}
self.router_socket.send_multipart(
[remote_address, msgpack.dumps(data)])
if data["ret"] == 0:
comm, rank = self.comms[remote_address.decode()]
self.send(comm, tensor.to(self.device), rank ^ 1,
self.send_stream)
else:
logger.warning(
"🚧Unexpected, Received message from %s, data:%s",
remote_address, data)
def have_sent_tensor_id(self, tensor_id: str):
request_id = tensor_id.split('#')[0] request_id = tensor_id.split('#')[0]
if request_id not in self.send_request_id_to_tensor_ids: if request_id not in self.send_request_id_to_tensor_ids:
self.send_request_id_to_tensor_ids[request_id] = set() self.send_request_id_to_tensor_ids[request_id] = set()
self.send_request_id_to_tensor_ids[request_id].add(tensor_id) self.send_request_id_to_tensor_ids[request_id].add(tensor_id)
def _have_received_tensor_id(self, tensor_id: str): def have_received_tensor_id(self, tensor_id: str):
request_id = tensor_id.split('#')[0] request_id = tensor_id.split('#')[0]
if request_id not in self.recv_request_id_to_tensor_ids: if request_id not in self.recv_request_id_to_tensor_ids:
self.recv_request_id_to_tensor_ids[request_id] = set() self.recv_request_id_to_tensor_ids[request_id] = set()
self.recv_request_id_to_tensor_ids[request_id].add(tensor_id) self.recv_request_id_to_tensor_ids[request_id].add(tensor_id)
def _send_async(self): def send_async(self):
while True: while True:
with self.send_queue_cv: with self.send_queue_cv:
while not self.send_queue: while not self.send_queue:
self.send_queue_cv.wait() self.send_queue_cv.wait()
tensor_id, remote_address, tensor = self.send_queue.popleft() item = self.send_queue.popleft()
if not self.send_queue: if not self.send_queue:
self.send_queue_cv.notify() self.send_queue_cv.notify()
self._send_sync(tensor_id, tensor, remote_address) self.send_sync(item)
def wait_for_sent(self): def wait_for_sent(self):
if self.send_type == "PUT_ASYNC": if self.send_type == "PUT_ASYNC":
@ -409,22 +427,21 @@ class P2pNcclEngine:
"🚧[PUT_ASYNC]It took %.3fms to wait for the send_queue" "🚧[PUT_ASYNC]It took %.3fms to wait for the send_queue"
" to be empty, rank:%d", duration * 1000, self.rank) " to be empty, rank:%d", duration * 1000, self.rank)
def _send_sync( def send_sync(self, item: SendQueueItem) -> bool:
self, if item.remote_address is None:
tensor_id: str,
tensor: torch.Tensor,
remote_address: typing.Optional[str] = None,
) -> bool:
if remote_address is None:
return False return False
if remote_address not in self.socks: if item.remote_address not in self.socks:
self._create_connect(remote_address) self.create_connect(item.remote_address)
sock = self.socks[remote_address] with self.send_stream:
comm, rank = self.comms[remote_address] tensor = self.extract_kv_from_layer(item.is_mla, item.tensor,
item.slot_mapping)
sock = self.socks[item.remote_address]
comm, rank = self.comms[item.remote_address]
data = { data = {
"cmd": "PUT", "cmd": "PUT",
"tensor_id": tensor_id, "tensor_id": item.tensor_id,
"shape": tensor.shape, "shape": tensor.shape,
"dtype": str(tensor.dtype).replace("torch.", "") "dtype": str(tensor.dtype).replace("torch.", "")
} }
@ -435,20 +452,21 @@ class P2pNcclEngine:
logger.error( logger.error(
"🔴Send Tensor, Peer Out Of Memory/Threshold, %s 👉 %s, " "🔴Send Tensor, Peer Out Of Memory/Threshold, %s 👉 %s, "
"MyRank:%s, data:%s, tensor:%s, size:%fGB, response:%s", "MyRank:%s, data:%s, tensor:%s, size:%fGB, response:%s",
self.zmq_address, remote_address, rank, data, tensor.shape, self.zmq_address, item.remote_address, rank, data,
tensor.shape,
tensor.element_size() * tensor.numel() / 1024**3, tensor.element_size() * tensor.numel() / 1024**3,
response.decode()) response.decode())
return False return False
self._send(comm, tensor.to(self.device), rank ^ 1, self.send_stream) self.send(comm, tensor.to(self.device), rank ^ 1, self.send_stream)
if self.send_type == "PUT_ASYNC": if self.send_type == "PUT_ASYNC":
self._have_sent_tensor_id(tensor_id) self.have_sent_tensor_id(item.tensor_id)
return True return True
def get_finished( def get_finished(
self, finished_req_ids: set[str], forward_context: "ForwardContext" self, finished_req_ids: set[str], no_compile_layers
) -> tuple[Optional[set[str]], Optional[set[str]]]: ) -> tuple[Optional[set[str]], Optional[set[str]]]:
""" """
Notifies worker-side connector ids of requests that have Notifies worker-side connector ids of requests that have
@ -463,7 +481,7 @@ class P2pNcclEngine:
# Clear the buffer upon request completion. # Clear the buffer upon request completion.
for request_id in finished_req_ids: for request_id in finished_req_ids:
for layer_name in forward_context.no_compile_layers: for layer_name in no_compile_layers:
tensor_id = request_id + "#" + layer_name tensor_id = request_id + "#" + layer_name
if tensor_id in self.recv_store: if tensor_id in self.recv_store:
with self.recv_store_cv: with self.recv_store_cv:
@ -472,7 +490,6 @@ class P2pNcclEngine:
request_id, None) request_id, None)
self.recv_request_id_to_tensor_ids.pop( self.recv_request_id_to_tensor_ids.pop(
request_id, None) request_id, None)
addr = 0
if isinstance(tensor, tuple): if isinstance(tensor, tuple):
addr, _, _ = tensor addr, _, _ = tensor
self.pool.free(addr) self.pool.free(addr)
@ -485,7 +502,7 @@ class P2pNcclEngine:
return finished_sending or None, finished_recving or None return finished_sending or None, finished_recving or None
def _ping(self): def ping(self):
sock = self.context.socket(zmq.DEALER) sock = self.context.socket(zmq.DEALER)
sock.setsockopt_string(zmq.IDENTITY, self.zmq_address) sock.setsockopt_string(zmq.IDENTITY, self.zmq_address)
logger.debug("ping start, zmq_address:%s", self.zmq_address) logger.debug("ping start, zmq_address:%s", self.zmq_address)
@ -499,7 +516,7 @@ class P2pNcclEngine:
sock.send(msgpack.dumps(data)) sock.send(msgpack.dumps(data))
time.sleep(3) time.sleep(3)
def _send(self, comm, tensor: torch.Tensor, dst: int, stream=None): def send(self, comm, tensor: torch.Tensor, dst: int, stream=None):
assert tensor.device == self.device, ( assert tensor.device == self.device, (
f"this nccl communicator is created to work on {self.device}, " f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {tensor.device}") f"but the input tensor is on {tensor.device}")
@ -512,7 +529,7 @@ class P2pNcclEngine:
comm, cudaStream_t(stream.cuda_stream)) comm, cudaStream_t(stream.cuda_stream))
stream.synchronize() stream.synchronize()
def _recv(self, comm, tensor: torch.Tensor, src: int, stream=None): def recv(self, comm, tensor: torch.Tensor, src: int, stream=None):
assert tensor.device == self.device, ( assert tensor.device == self.device, (
f"this nccl communicator is created to work on {self.device}, " f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {tensor.device}") f"but the input tensor is on {tensor.device}")
@ -531,3 +548,21 @@ class P2pNcclEngine:
self._send_thread.join() self._send_thread.join()
if self._ping_thread is not None: if self._ping_thread is not None:
self._ping_thread.join() self._ping_thread.join()
@staticmethod
def extract_kv_from_layer(
is_mla: bool,
layer: torch.Tensor,
slot_mapping: torch.Tensor,
) -> torch.Tensor:
"""Extract the KV cache from the layer.
Assume the shape of the layer is (2, num_pages, page_size, xxx)
if MLA is not used, and (num_pages, page_size, xxx) otherwise.
"""
if is_mla:
num_pages, page_size = layer.shape[0], layer.shape[1]
return layer.reshape(num_pages * page_size, -1)[slot_mapping, ...]
num_pages, page_size = layer.shape[1], layer.shape[2]
return layer.reshape(2, num_pages * page_size, -1)[:, slot_mapping,
...]