change disagg_prefill example to use zmq

Signed-off-by: clark <panf2333@gmail.com>
This commit is contained in:
clark 2025-01-20 23:14:37 +08:00
parent 298298f97d
commit d6945ecdf0
2 changed files with 15 additions and 9 deletions

View File

@ -26,14 +26,6 @@ cleanup() {
export VLLM_HOST_IP=$(hostname -I | awk '{print $1}')
# install quart first -- required for disagg prefill proxy serve
if python3 -c "import quart" &> /dev/null; then
echo "Quart is already installed."
else
echo "Quart is not installed. Installing..."
python3 -m pip install quart
fi
# a function that waits vLLM server to start
wait_for_server() {
local port=$1
@ -49,6 +41,7 @@ wait_for_server() {
# prefilling instance, which is the KV producer
CUDA_VISIBLE_DEVICES=0 vllm serve $MODEL_NAME \
--port 8100 \
--zmq-server-port 7010 \
--max-model-len 100 \
--gpu-memory-utilization 0.8 \
--trust-remote-code \
@ -58,13 +51,25 @@ CUDA_VISIBLE_DEVICES=0 vllm serve $MODEL_NAME \
# decoding instance, which is the KV consumer
CUDA_VISIBLE_DEVICES=1 vllm serve $MODEL_NAME \
--port 8200 \
--zmq-server-port 7011 \
--max-model-len 100 \
--gpu-memory-utilization 0.8 \
--trust-remote-code \
--kv-transfer-config \
'{"kv_connector":"PyNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2}' &
# wait until prefill and decode instances are ready
# launch a proxy server that opens the service at port 8000
# the workflow of this proxy:
# - send the request to prefill vLLM instance (via zmq port 7010), change max_tokens
# to 1
# - after the prefill vLLM finishes prefill, send the request to decode vLLM
# instance (via zmq port 7011)
vllm connect --port 8000 \
--prefill-addr 127.0.0.1:7010 \
--decode-addr 127.0.0.1:7011 &
# wait until prefill, decode instances and proxy are ready
wait_for_server 8000
wait_for_server 8100
wait_for_server 8200

View File

@ -47,6 +47,7 @@ async def lifespan(app: FastAPI):
logger.info("success create_socket_pool sockets_decode")
yield
## close zmq context
logger.info("shutdown disagg connector")
logger.info("term zmqctx")
app.state.zmqctx.destroy(linger=0)