[BugFix][PD]: make example proxy usable with P2pNcclConnector (#26628)

Signed-off-by: PAN <1162953505@qq.com>
This commit is contained in:
Pan Li 2025-11-21 01:38:31 +08:00 committed by GitHub
parent 22924383e1
commit e5bfcb6a88
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 167 additions and 97 deletions

View File

@ -5,11 +5,12 @@ import argparse
import asyncio import asyncio
import logging import logging
import os import os
import time
import uuid
from urllib.parse import urlparse
import aiohttp import aiohttp
from quart import Quart, Response, make_response, request from quart import Quart, Response, make_response, request
from rate_limiter import RateLimiter
from request_queue import RequestQueue
# Configure logging # Configure logging
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
@ -24,26 +25,8 @@ def parse_args():
parser.add_argument( parser.add_argument(
"--timeout", "--timeout",
type=float, type=float,
default=300, default=6 * 60 * 60,
help="Timeout for backend service requests in seconds (default: 300)", help="Timeout for backend service requests in seconds (default: 21600)",
)
parser.add_argument(
"--max-concurrent",
type=int,
default=100,
help="Maximum concurrent requests to backend services (default: 100)",
)
parser.add_argument(
"--queue-size",
type=int,
default=500,
help="Maximum number of requests in the queue (default: 500)",
)
parser.add_argument(
"--rate-limit",
type=int,
default=40,
help="Maximum requests per second (default: 40)",
) )
parser.add_argument( parser.add_argument(
"--port", "--port",
@ -54,14 +37,32 @@ def parse_args():
parser.add_argument( parser.add_argument(
"--prefill-url", "--prefill-url",
type=str, type=str,
default="http://localhost:8100/v1/completions", default="http://localhost:8100",
help="Prefill service endpoint URL", help="Prefill service base URL (protocol + host[:port])",
) )
parser.add_argument( parser.add_argument(
"--decode-url", "--decode-url",
type=str, type=str,
default="http://localhost:8200/v1/completions", default="http://localhost:8200",
help="Decode service endpoint URL", help="Decode service base URL (protocol + host[:port])",
)
parser.add_argument(
"--kv-host",
type=str,
default="localhost",
help="Hostname or IP used by KV transfer (default: localhost)",
)
parser.add_argument(
"--prefill-kv-port",
type=int,
default=14579,
help="Prefill KV port (default: 14579)",
)
parser.add_argument(
"--decode-kv-port",
type=int,
default=14580,
help="Decode KV port (default: 14580)",
) )
return parser.parse_args() return parser.parse_args()
@ -73,70 +74,129 @@ def main():
# Initialize configuration using command line parameters # Initialize configuration using command line parameters
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=args.timeout) AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=args.timeout)
MAX_CONCURRENT_REQUESTS = args.max_concurrent
REQUEST_QUEUE_SIZE = args.queue_size
RATE_LIMIT = args.rate_limit
PREFILL_SERVICE_URL = args.prefill_url PREFILL_SERVICE_URL = args.prefill_url
DECODE_SERVICE_URL = args.decode_url DECODE_SERVICE_URL = args.decode_url
PORT = args.port PORT = args.port
PREFILL_KV_ADDR = f"{args.kv_host}:{args.prefill_kv_port}"
DECODE_KV_ADDR = f"{args.kv_host}:{args.decode_kv_port}"
logger.info(
"Proxy resolved KV addresses -> prefill: %s, decode: %s",
PREFILL_KV_ADDR,
DECODE_KV_ADDR,
)
app = Quart(__name__) app = Quart(__name__)
# Initialize the rate limiter and request queue # Attach the configuration object to the application instance so helper
rate_limiter = RateLimiter(RATE_LIMIT) # coroutines can read the resolved backend URLs and timeouts without using
request_queue = RequestQueue(MAX_CONCURRENT_REQUESTS, REQUEST_QUEUE_SIZE) # globals.
# Attach the configuration object to the application instance
app.config.update( app.config.update(
{ {
"AIOHTTP_TIMEOUT": AIOHTTP_TIMEOUT, "AIOHTTP_TIMEOUT": AIOHTTP_TIMEOUT,
"rate_limiter": rate_limiter,
"request_queue": request_queue,
"PREFILL_SERVICE_URL": PREFILL_SERVICE_URL, "PREFILL_SERVICE_URL": PREFILL_SERVICE_URL,
"DECODE_SERVICE_URL": DECODE_SERVICE_URL, "DECODE_SERVICE_URL": DECODE_SERVICE_URL,
"PREFILL_KV_ADDR": PREFILL_KV_ADDR,
"DECODE_KV_ADDR": DECODE_KV_ADDR,
} }
) )
# Start queue processing on app startup def _normalize_base_url(url: str) -> str:
@app.before_serving """Remove any trailing slash so path joins behave predictably."""
async def startup(): return url.rstrip("/")
"""Start request processing task when app starts serving"""
asyncio.create_task(request_queue.process())
async def forward_request(url, data): def _get_host_port(url: str) -> str:
"""Forward request to backend service with rate limiting and error handling""" """Return the hostname:port portion for logging and KV headers."""
headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"} parsed = urlparse(url)
host = parsed.hostname or "localhost"
port = parsed.port
if port is None:
port = 80 if parsed.scheme == "http" else 443
return f"{host}:{port}"
# Use rate limiter as context manager PREFILL_BASE = _normalize_base_url(PREFILL_SERVICE_URL)
async with ( DECODE_BASE = _normalize_base_url(DECODE_SERVICE_URL)
rate_limiter, KV_TARGET = _get_host_port(DECODE_SERVICE_URL)
aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session,
): def _build_headers(request_id: str) -> dict[str, str]:
try: """Construct the headers expected by vLLM's P2P disagg connector."""
async with session.post( headers: dict[str, str] = {"X-Request-Id": request_id, "X-KV-Target": KV_TARGET}
url=url, json=data, headers=headers api_key = os.environ.get("OPENAI_API_KEY")
) as response: if api_key:
if response.status == 200: headers["Authorization"] = f"Bearer {api_key}"
# Stream response chunks return headers
async for chunk_bytes in response.content.iter_chunked(1024):
yield chunk_bytes async def _run_prefill(
else: request_path: str,
# Handle backend service errors payload: dict,
error_text = await response.text() headers: dict[str, str],
logger.error( request_id: str,
"Backend service error: %s - %s", ):
response.status, url = f"{PREFILL_BASE}{request_path}"
error_text, start_ts = time.perf_counter()
) logger.info("[prefill] start request_id=%s url=%s", request_id, url)
yield b'{"error": "Backend service error"}' try:
except aiohttp.ClientError as e: async with (
# Handle connection errors aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session,
logger.error("Connection error to %s: %s", url, str(e)) session.post(url=url, json=payload, headers=headers) as resp,
yield b'{"error": "Service unavailable"}' ):
except asyncio.TimeoutError: if resp.status != 200:
# Handle timeout errors error_text = await resp.text()
logger.error("Timeout connecting to %s", url) raise RuntimeError(
yield b'{"error": "Service timeout"}' f"Prefill backend error {resp.status}: {error_text}"
)
await resp.read()
logger.info(
"[prefill] done request_id=%s status=%s elapsed=%.2fs",
request_id,
resp.status,
time.perf_counter() - start_ts,
)
except asyncio.TimeoutError as exc:
raise RuntimeError(f"Prefill service timeout at {url}") from exc
except aiohttp.ClientError as exc:
raise RuntimeError(f"Prefill service unavailable at {url}") from exc
async def _stream_decode(
request_path: str,
payload: dict,
headers: dict[str, str],
request_id: str,
):
url = f"{DECODE_BASE}{request_path}"
# Stream tokens from the decode service once the prefill stage has
# materialized KV caches on the target workers.
logger.info("[decode] start request_id=%s url=%s", request_id, url)
try:
async with (
aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session,
session.post(url=url, json=payload, headers=headers) as resp,
):
if resp.status != 200:
error_text = await resp.text()
logger.error(
"Decode backend error %s - %s", resp.status, error_text
)
err_msg = (
'{"error": "Decode backend error ' + str(resp.status) + '"}'
)
yield err_msg.encode()
return
logger.info(
"[decode] streaming response request_id=%s status=%s",
request_id,
resp.status,
)
async for chunk_bytes in resp.content.iter_chunked(1024):
yield chunk_bytes
logger.info("[decode] finished streaming request_id=%s", request_id)
except asyncio.TimeoutError:
logger.error("Decode service timeout at %s", url)
yield b'{"error": "Decode service timeout"}'
except aiohttp.ClientError as exc:
logger.error("Decode service error at %s: %s", url, exc)
yield b'{"error": "Decode service unavailable"}'
async def process_request(): async def process_request():
"""Process a single request through prefill and decode stages""" """Process a single request through prefill and decode stages"""
@ -146,13 +206,27 @@ def main():
# Create prefill request (max_tokens=1) # Create prefill request (max_tokens=1)
prefill_request = original_request_data.copy() prefill_request = original_request_data.copy()
prefill_request["max_tokens"] = 1 prefill_request["max_tokens"] = 1
if "max_completion_tokens" in prefill_request:
prefill_request["max_completion_tokens"] = 1
# Execute prefill stage # Execute prefill stage
async for _ in forward_request(PREFILL_SERVICE_URL, prefill_request): # The request id encodes both KV socket addresses so the backend can
continue # shuttle tensors directly via NCCL once the prefill response
# completes.
request_id = (
f"___prefill_addr_{PREFILL_KV_ADDR}___decode_addr_"
f"{DECODE_KV_ADDR}_{uuid.uuid4().hex}"
)
headers = _build_headers(request_id)
await _run_prefill(request.path, prefill_request, headers, request_id)
# Execute decode stage and stream response # Execute decode stage and stream response
generator = forward_request(DECODE_SERVICE_URL, original_request_data) # Pass the unmodified user request so the decode phase can continue
# sampling with the already-populated KV cache.
generator = _stream_decode(
request.path, original_request_data, headers, request_id
)
response = await make_response(generator) response = await make_response(generator)
response.timeout = None # Disable timeout for streaming response response.timeout = None # Disable timeout for streaming response
return response return response
@ -168,23 +242,10 @@ def main():
@app.route("/v1/completions", methods=["POST"]) @app.route("/v1/completions", methods=["POST"])
async def handle_request(): async def handle_request():
"""Handle incoming API requests with concurrency and rate limiting""" """Handle incoming API requests with concurrency and rate limiting"""
# Create task for request processing
task = asyncio.create_task(process_request())
# Enqueue request or reject if queue is full
if not await request_queue.enqueue(task):
return Response(
response=b'{"error": "Server busy, try again later"}',
status=503,
content_type="application/json",
)
try: try:
# Return the response from the processing task return await process_request()
return await task
except asyncio.CancelledError: except asyncio.CancelledError:
# Handle task cancellation (timeout or queue full) logger.warning("Request cancelled")
logger.warning("Request cancelled due to timeout or queue full")
return Response( return Response(
response=b'{"error": "Request cancelled"}', response=b'{"error": "Request cancelled"}',
status=503, status=503,

View File

@ -24,7 +24,14 @@ cleanup() {
exit 0 exit 0
} }
export VLLM_HOST_IP=$(hostname -I | awk '{print $1}')
if [[ -z "${VLLM_HOST_IP:-}" ]]; then
export VLLM_HOST_IP=127.0.0.1
echo "Using default VLLM_HOST_IP=127.0.0.1 (override by exporting VLLM_HOST_IP before running this script)"
else
echo "Using provided VLLM_HOST_IP=${VLLM_HOST_IP}"
fi
# install quart first -- required for disagg prefill proxy serve # install quart first -- required for disagg prefill proxy serve
if python3 -c "import quart" &> /dev/null; then if python3 -c "import quart" &> /dev/null; then
@ -38,7 +45,7 @@ fi
wait_for_server() { wait_for_server() {
local port=$1 local port=$1
timeout 1200 bash -c " timeout 1200 bash -c "
until curl -s localhost:${port}/v1/completions > /dev/null; do until curl -i localhost:${port}/v1/models > /dev/null; do
sleep 1 sleep 1
done" && return 0 || return 1 done" && return 0 || return 1
} }
@ -48,21 +55,23 @@ wait_for_server() {
# prefilling instance, which is the KV producer # prefilling instance, which is the KV producer
CUDA_VISIBLE_DEVICES=0 vllm serve $MODEL_NAME \ CUDA_VISIBLE_DEVICES=0 vllm serve $MODEL_NAME \
--host 0.0.0.0 \
--port 8100 \ --port 8100 \
--max-model-len 100 \ --max-model-len 100 \
--gpu-memory-utilization 0.8 \ --gpu-memory-utilization 0.8 \
--trust-remote-code \ --trust-remote-code \
--kv-transfer-config \ --kv-transfer-config \
'{"kv_connector":"P2pNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2}' & '{"kv_connector":"P2pNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2,"kv_buffer_size":"1e9","kv_port":"14579","kv_connector_extra_config":{"proxy_ip":"'"$VLLM_HOST_IP"'","proxy_port":"30001","http_ip":"'"$VLLM_HOST_IP"'","http_port":"8100","send_type":"PUT_ASYNC"}}' &
# decoding instance, which is the KV consumer # decoding instance, which is the KV consumer
CUDA_VISIBLE_DEVICES=1 vllm serve $MODEL_NAME \ CUDA_VISIBLE_DEVICES=1 vllm serve $MODEL_NAME \
--host 0.0.0.0 \
--port 8200 \ --port 8200 \
--max-model-len 100 \ --max-model-len 100 \
--gpu-memory-utilization 0.8 \ --gpu-memory-utilization 0.8 \
--trust-remote-code \ --trust-remote-code \
--kv-transfer-config \ --kv-transfer-config \
'{"kv_connector":"P2pNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2}' & '{"kv_connector":"P2pNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2,"kv_buffer_size":"1e10","kv_port":"14580","kv_connector_extra_config":{"proxy_ip":"'"$VLLM_HOST_IP"'","proxy_port":"30001","http_ip":"'"$VLLM_HOST_IP"'","http_port":"8200","send_type":"PUT_ASYNC"}}' &
# wait until prefill and decode instances are ready # wait until prefill and decode instances are ready
wait_for_server 8100 wait_for_server 8100