mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-18 16:26:59 +08:00
[BugFix][PD]: make example proxy usable with P2pNcclConnector (#26628)
Signed-off-by: PAN <1162953505@qq.com>
This commit is contained in:
parent
22924383e1
commit
e5bfcb6a88
@ -5,11 +5,12 @@ import argparse
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
from quart import Quart, Response, make_response, request
|
from quart import Quart, Response, make_response, request
|
||||||
from rate_limiter import RateLimiter
|
|
||||||
from request_queue import RequestQueue
|
|
||||||
|
|
||||||
# Configure logging
|
# Configure logging
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
@ -24,26 +25,8 @@ def parse_args():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--timeout",
|
"--timeout",
|
||||||
type=float,
|
type=float,
|
||||||
default=300,
|
default=6 * 60 * 60,
|
||||||
help="Timeout for backend service requests in seconds (default: 300)",
|
help="Timeout for backend service requests in seconds (default: 21600)",
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--max-concurrent",
|
|
||||||
type=int,
|
|
||||||
default=100,
|
|
||||||
help="Maximum concurrent requests to backend services (default: 100)",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--queue-size",
|
|
||||||
type=int,
|
|
||||||
default=500,
|
|
||||||
help="Maximum number of requests in the queue (default: 500)",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--rate-limit",
|
|
||||||
type=int,
|
|
||||||
default=40,
|
|
||||||
help="Maximum requests per second (default: 40)",
|
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--port",
|
"--port",
|
||||||
@ -54,14 +37,32 @@ def parse_args():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--prefill-url",
|
"--prefill-url",
|
||||||
type=str,
|
type=str,
|
||||||
default="http://localhost:8100/v1/completions",
|
default="http://localhost:8100",
|
||||||
help="Prefill service endpoint URL",
|
help="Prefill service base URL (protocol + host[:port])",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--decode-url",
|
"--decode-url",
|
||||||
type=str,
|
type=str,
|
||||||
default="http://localhost:8200/v1/completions",
|
default="http://localhost:8200",
|
||||||
help="Decode service endpoint URL",
|
help="Decode service base URL (protocol + host[:port])",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--kv-host",
|
||||||
|
type=str,
|
||||||
|
default="localhost",
|
||||||
|
help="Hostname or IP used by KV transfer (default: localhost)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--prefill-kv-port",
|
||||||
|
type=int,
|
||||||
|
default=14579,
|
||||||
|
help="Prefill KV port (default: 14579)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--decode-kv-port",
|
||||||
|
type=int,
|
||||||
|
default=14580,
|
||||||
|
help="Decode KV port (default: 14580)",
|
||||||
)
|
)
|
||||||
|
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
@ -73,70 +74,129 @@ def main():
|
|||||||
|
|
||||||
# Initialize configuration using command line parameters
|
# Initialize configuration using command line parameters
|
||||||
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=args.timeout)
|
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=args.timeout)
|
||||||
MAX_CONCURRENT_REQUESTS = args.max_concurrent
|
|
||||||
REQUEST_QUEUE_SIZE = args.queue_size
|
|
||||||
RATE_LIMIT = args.rate_limit
|
|
||||||
PREFILL_SERVICE_URL = args.prefill_url
|
PREFILL_SERVICE_URL = args.prefill_url
|
||||||
DECODE_SERVICE_URL = args.decode_url
|
DECODE_SERVICE_URL = args.decode_url
|
||||||
PORT = args.port
|
PORT = args.port
|
||||||
|
|
||||||
|
PREFILL_KV_ADDR = f"{args.kv_host}:{args.prefill_kv_port}"
|
||||||
|
DECODE_KV_ADDR = f"{args.kv_host}:{args.decode_kv_port}"
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Proxy resolved KV addresses -> prefill: %s, decode: %s",
|
||||||
|
PREFILL_KV_ADDR,
|
||||||
|
DECODE_KV_ADDR,
|
||||||
|
)
|
||||||
|
|
||||||
app = Quart(__name__)
|
app = Quart(__name__)
|
||||||
|
|
||||||
# Initialize the rate limiter and request queue
|
# Attach the configuration object to the application instance so helper
|
||||||
rate_limiter = RateLimiter(RATE_LIMIT)
|
# coroutines can read the resolved backend URLs and timeouts without using
|
||||||
request_queue = RequestQueue(MAX_CONCURRENT_REQUESTS, REQUEST_QUEUE_SIZE)
|
# globals.
|
||||||
|
|
||||||
# Attach the configuration object to the application instance
|
|
||||||
app.config.update(
|
app.config.update(
|
||||||
{
|
{
|
||||||
"AIOHTTP_TIMEOUT": AIOHTTP_TIMEOUT,
|
"AIOHTTP_TIMEOUT": AIOHTTP_TIMEOUT,
|
||||||
"rate_limiter": rate_limiter,
|
|
||||||
"request_queue": request_queue,
|
|
||||||
"PREFILL_SERVICE_URL": PREFILL_SERVICE_URL,
|
"PREFILL_SERVICE_URL": PREFILL_SERVICE_URL,
|
||||||
"DECODE_SERVICE_URL": DECODE_SERVICE_URL,
|
"DECODE_SERVICE_URL": DECODE_SERVICE_URL,
|
||||||
|
"PREFILL_KV_ADDR": PREFILL_KV_ADDR,
|
||||||
|
"DECODE_KV_ADDR": DECODE_KV_ADDR,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
# Start queue processing on app startup
|
def _normalize_base_url(url: str) -> str:
|
||||||
@app.before_serving
|
"""Remove any trailing slash so path joins behave predictably."""
|
||||||
async def startup():
|
return url.rstrip("/")
|
||||||
"""Start request processing task when app starts serving"""
|
|
||||||
asyncio.create_task(request_queue.process())
|
|
||||||
|
|
||||||
async def forward_request(url, data):
|
def _get_host_port(url: str) -> str:
|
||||||
"""Forward request to backend service with rate limiting and error handling"""
|
"""Return the hostname:port portion for logging and KV headers."""
|
||||||
headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
|
parsed = urlparse(url)
|
||||||
|
host = parsed.hostname or "localhost"
|
||||||
|
port = parsed.port
|
||||||
|
if port is None:
|
||||||
|
port = 80 if parsed.scheme == "http" else 443
|
||||||
|
return f"{host}:{port}"
|
||||||
|
|
||||||
# Use rate limiter as context manager
|
PREFILL_BASE = _normalize_base_url(PREFILL_SERVICE_URL)
|
||||||
async with (
|
DECODE_BASE = _normalize_base_url(DECODE_SERVICE_URL)
|
||||||
rate_limiter,
|
KV_TARGET = _get_host_port(DECODE_SERVICE_URL)
|
||||||
aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session,
|
|
||||||
):
|
def _build_headers(request_id: str) -> dict[str, str]:
|
||||||
try:
|
"""Construct the headers expected by vLLM's P2P disagg connector."""
|
||||||
async with session.post(
|
headers: dict[str, str] = {"X-Request-Id": request_id, "X-KV-Target": KV_TARGET}
|
||||||
url=url, json=data, headers=headers
|
api_key = os.environ.get("OPENAI_API_KEY")
|
||||||
) as response:
|
if api_key:
|
||||||
if response.status == 200:
|
headers["Authorization"] = f"Bearer {api_key}"
|
||||||
# Stream response chunks
|
return headers
|
||||||
async for chunk_bytes in response.content.iter_chunked(1024):
|
|
||||||
yield chunk_bytes
|
async def _run_prefill(
|
||||||
else:
|
request_path: str,
|
||||||
# Handle backend service errors
|
payload: dict,
|
||||||
error_text = await response.text()
|
headers: dict[str, str],
|
||||||
logger.error(
|
request_id: str,
|
||||||
"Backend service error: %s - %s",
|
):
|
||||||
response.status,
|
url = f"{PREFILL_BASE}{request_path}"
|
||||||
error_text,
|
start_ts = time.perf_counter()
|
||||||
)
|
logger.info("[prefill] start request_id=%s url=%s", request_id, url)
|
||||||
yield b'{"error": "Backend service error"}'
|
try:
|
||||||
except aiohttp.ClientError as e:
|
async with (
|
||||||
# Handle connection errors
|
aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session,
|
||||||
logger.error("Connection error to %s: %s", url, str(e))
|
session.post(url=url, json=payload, headers=headers) as resp,
|
||||||
yield b'{"error": "Service unavailable"}'
|
):
|
||||||
except asyncio.TimeoutError:
|
if resp.status != 200:
|
||||||
# Handle timeout errors
|
error_text = await resp.text()
|
||||||
logger.error("Timeout connecting to %s", url)
|
raise RuntimeError(
|
||||||
yield b'{"error": "Service timeout"}'
|
f"Prefill backend error {resp.status}: {error_text}"
|
||||||
|
)
|
||||||
|
await resp.read()
|
||||||
|
logger.info(
|
||||||
|
"[prefill] done request_id=%s status=%s elapsed=%.2fs",
|
||||||
|
request_id,
|
||||||
|
resp.status,
|
||||||
|
time.perf_counter() - start_ts,
|
||||||
|
)
|
||||||
|
except asyncio.TimeoutError as exc:
|
||||||
|
raise RuntimeError(f"Prefill service timeout at {url}") from exc
|
||||||
|
except aiohttp.ClientError as exc:
|
||||||
|
raise RuntimeError(f"Prefill service unavailable at {url}") from exc
|
||||||
|
|
||||||
|
async def _stream_decode(
|
||||||
|
request_path: str,
|
||||||
|
payload: dict,
|
||||||
|
headers: dict[str, str],
|
||||||
|
request_id: str,
|
||||||
|
):
|
||||||
|
url = f"{DECODE_BASE}{request_path}"
|
||||||
|
# Stream tokens from the decode service once the prefill stage has
|
||||||
|
# materialized KV caches on the target workers.
|
||||||
|
logger.info("[decode] start request_id=%s url=%s", request_id, url)
|
||||||
|
try:
|
||||||
|
async with (
|
||||||
|
aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session,
|
||||||
|
session.post(url=url, json=payload, headers=headers) as resp,
|
||||||
|
):
|
||||||
|
if resp.status != 200:
|
||||||
|
error_text = await resp.text()
|
||||||
|
logger.error(
|
||||||
|
"Decode backend error %s - %s", resp.status, error_text
|
||||||
|
)
|
||||||
|
err_msg = (
|
||||||
|
'{"error": "Decode backend error ' + str(resp.status) + '"}'
|
||||||
|
)
|
||||||
|
yield err_msg.encode()
|
||||||
|
return
|
||||||
|
logger.info(
|
||||||
|
"[decode] streaming response request_id=%s status=%s",
|
||||||
|
request_id,
|
||||||
|
resp.status,
|
||||||
|
)
|
||||||
|
async for chunk_bytes in resp.content.iter_chunked(1024):
|
||||||
|
yield chunk_bytes
|
||||||
|
logger.info("[decode] finished streaming request_id=%s", request_id)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.error("Decode service timeout at %s", url)
|
||||||
|
yield b'{"error": "Decode service timeout"}'
|
||||||
|
except aiohttp.ClientError as exc:
|
||||||
|
logger.error("Decode service error at %s: %s", url, exc)
|
||||||
|
yield b'{"error": "Decode service unavailable"}'
|
||||||
|
|
||||||
async def process_request():
|
async def process_request():
|
||||||
"""Process a single request through prefill and decode stages"""
|
"""Process a single request through prefill and decode stages"""
|
||||||
@ -146,13 +206,27 @@ def main():
|
|||||||
# Create prefill request (max_tokens=1)
|
# Create prefill request (max_tokens=1)
|
||||||
prefill_request = original_request_data.copy()
|
prefill_request = original_request_data.copy()
|
||||||
prefill_request["max_tokens"] = 1
|
prefill_request["max_tokens"] = 1
|
||||||
|
if "max_completion_tokens" in prefill_request:
|
||||||
|
prefill_request["max_completion_tokens"] = 1
|
||||||
|
|
||||||
# Execute prefill stage
|
# Execute prefill stage
|
||||||
async for _ in forward_request(PREFILL_SERVICE_URL, prefill_request):
|
# The request id encodes both KV socket addresses so the backend can
|
||||||
continue
|
# shuttle tensors directly via NCCL once the prefill response
|
||||||
|
# completes.
|
||||||
|
request_id = (
|
||||||
|
f"___prefill_addr_{PREFILL_KV_ADDR}___decode_addr_"
|
||||||
|
f"{DECODE_KV_ADDR}_{uuid.uuid4().hex}"
|
||||||
|
)
|
||||||
|
|
||||||
|
headers = _build_headers(request_id)
|
||||||
|
await _run_prefill(request.path, prefill_request, headers, request_id)
|
||||||
|
|
||||||
# Execute decode stage and stream response
|
# Execute decode stage and stream response
|
||||||
generator = forward_request(DECODE_SERVICE_URL, original_request_data)
|
# Pass the unmodified user request so the decode phase can continue
|
||||||
|
# sampling with the already-populated KV cache.
|
||||||
|
generator = _stream_decode(
|
||||||
|
request.path, original_request_data, headers, request_id
|
||||||
|
)
|
||||||
response = await make_response(generator)
|
response = await make_response(generator)
|
||||||
response.timeout = None # Disable timeout for streaming response
|
response.timeout = None # Disable timeout for streaming response
|
||||||
return response
|
return response
|
||||||
@ -168,23 +242,10 @@ def main():
|
|||||||
@app.route("/v1/completions", methods=["POST"])
|
@app.route("/v1/completions", methods=["POST"])
|
||||||
async def handle_request():
|
async def handle_request():
|
||||||
"""Handle incoming API requests with concurrency and rate limiting"""
|
"""Handle incoming API requests with concurrency and rate limiting"""
|
||||||
# Create task for request processing
|
|
||||||
task = asyncio.create_task(process_request())
|
|
||||||
|
|
||||||
# Enqueue request or reject if queue is full
|
|
||||||
if not await request_queue.enqueue(task):
|
|
||||||
return Response(
|
|
||||||
response=b'{"error": "Server busy, try again later"}',
|
|
||||||
status=503,
|
|
||||||
content_type="application/json",
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Return the response from the processing task
|
return await process_request()
|
||||||
return await task
|
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
# Handle task cancellation (timeout or queue full)
|
logger.warning("Request cancelled")
|
||||||
logger.warning("Request cancelled due to timeout or queue full")
|
|
||||||
return Response(
|
return Response(
|
||||||
response=b'{"error": "Request cancelled"}',
|
response=b'{"error": "Request cancelled"}',
|
||||||
status=503,
|
status=503,
|
||||||
|
|||||||
@ -24,7 +24,14 @@ cleanup() {
|
|||||||
exit 0
|
exit 0
|
||||||
}
|
}
|
||||||
|
|
||||||
export VLLM_HOST_IP=$(hostname -I | awk '{print $1}')
|
|
||||||
|
if [[ -z "${VLLM_HOST_IP:-}" ]]; then
|
||||||
|
export VLLM_HOST_IP=127.0.0.1
|
||||||
|
echo "Using default VLLM_HOST_IP=127.0.0.1 (override by exporting VLLM_HOST_IP before running this script)"
|
||||||
|
else
|
||||||
|
echo "Using provided VLLM_HOST_IP=${VLLM_HOST_IP}"
|
||||||
|
fi
|
||||||
|
|
||||||
|
|
||||||
# install quart first -- required for disagg prefill proxy serve
|
# install quart first -- required for disagg prefill proxy serve
|
||||||
if python3 -c "import quart" &> /dev/null; then
|
if python3 -c "import quart" &> /dev/null; then
|
||||||
@ -38,7 +45,7 @@ fi
|
|||||||
wait_for_server() {
|
wait_for_server() {
|
||||||
local port=$1
|
local port=$1
|
||||||
timeout 1200 bash -c "
|
timeout 1200 bash -c "
|
||||||
until curl -s localhost:${port}/v1/completions > /dev/null; do
|
until curl -i localhost:${port}/v1/models > /dev/null; do
|
||||||
sleep 1
|
sleep 1
|
||||||
done" && return 0 || return 1
|
done" && return 0 || return 1
|
||||||
}
|
}
|
||||||
@ -48,21 +55,23 @@ wait_for_server() {
|
|||||||
|
|
||||||
# prefilling instance, which is the KV producer
|
# prefilling instance, which is the KV producer
|
||||||
CUDA_VISIBLE_DEVICES=0 vllm serve $MODEL_NAME \
|
CUDA_VISIBLE_DEVICES=0 vllm serve $MODEL_NAME \
|
||||||
|
--host 0.0.0.0 \
|
||||||
--port 8100 \
|
--port 8100 \
|
||||||
--max-model-len 100 \
|
--max-model-len 100 \
|
||||||
--gpu-memory-utilization 0.8 \
|
--gpu-memory-utilization 0.8 \
|
||||||
--trust-remote-code \
|
--trust-remote-code \
|
||||||
--kv-transfer-config \
|
--kv-transfer-config \
|
||||||
'{"kv_connector":"P2pNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2}' &
|
'{"kv_connector":"P2pNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2,"kv_buffer_size":"1e9","kv_port":"14579","kv_connector_extra_config":{"proxy_ip":"'"$VLLM_HOST_IP"'","proxy_port":"30001","http_ip":"'"$VLLM_HOST_IP"'","http_port":"8100","send_type":"PUT_ASYNC"}}' &
|
||||||
|
|
||||||
# decoding instance, which is the KV consumer
|
# decoding instance, which is the KV consumer
|
||||||
CUDA_VISIBLE_DEVICES=1 vllm serve $MODEL_NAME \
|
CUDA_VISIBLE_DEVICES=1 vllm serve $MODEL_NAME \
|
||||||
|
--host 0.0.0.0 \
|
||||||
--port 8200 \
|
--port 8200 \
|
||||||
--max-model-len 100 \
|
--max-model-len 100 \
|
||||||
--gpu-memory-utilization 0.8 \
|
--gpu-memory-utilization 0.8 \
|
||||||
--trust-remote-code \
|
--trust-remote-code \
|
||||||
--kv-transfer-config \
|
--kv-transfer-config \
|
||||||
'{"kv_connector":"P2pNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2}' &
|
'{"kv_connector":"P2pNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2,"kv_buffer_size":"1e10","kv_port":"14580","kv_connector_extra_config":{"proxy_ip":"'"$VLLM_HOST_IP"'","proxy_port":"30001","http_ip":"'"$VLLM_HOST_IP"'","http_port":"8200","send_type":"PUT_ASYNC"}}' &
|
||||||
|
|
||||||
# wait until prefill and decode instances are ready
|
# wait until prefill and decode instances are ready
|
||||||
wait_for_server 8100
|
wait_for_server 8100
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user