diff --git a/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py b/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py
index 904f805349148..d072c03c440b2 100644
--- a/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py
+++ b/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py
@@ -5,11 +5,12 @@
 import argparse
 import asyncio
 import logging
 import os
+import time
+import uuid
+from urllib.parse import urlparse

 import aiohttp
 from quart import Quart, Response, make_response, request
-from rate_limiter import RateLimiter
-from request_queue import RequestQueue

 # Configure logging
 logging.basicConfig(level=logging.INFO)
@@ -24,26 +25,8 @@ def parse_args():
     parser.add_argument(
         "--timeout",
         type=float,
-        default=300,
-        help="Timeout for backend service requests in seconds (default: 300)",
-    )
-    parser.add_argument(
-        "--max-concurrent",
-        type=int,
-        default=100,
-        help="Maximum concurrent requests to backend services (default: 100)",
-    )
-    parser.add_argument(
-        "--queue-size",
-        type=int,
-        default=500,
-        help="Maximum number of requests in the queue (default: 500)",
-    )
-    parser.add_argument(
-        "--rate-limit",
-        type=int,
-        default=40,
-        help="Maximum requests per second (default: 40)",
+        default=6 * 60 * 60,
+        help="Timeout for backend service requests in seconds (default: 21600)",
     )
     parser.add_argument(
         "--port",
@@ -54,14 +37,32 @@ def parse_args():
     parser.add_argument(
         "--prefill-url",
         type=str,
-        default="http://localhost:8100/v1/completions",
-        help="Prefill service endpoint URL",
+        default="http://localhost:8100",
+        help="Prefill service base URL (protocol + host[:port])",
     )
     parser.add_argument(
         "--decode-url",
         type=str,
-        default="http://localhost:8200/v1/completions",
-        help="Decode service endpoint URL",
+        default="http://localhost:8200",
+        help="Decode service base URL (protocol + host[:port])",
+    )
+    parser.add_argument(
+        "--kv-host",
+        type=str,
+        default="localhost",
+        help="Hostname or IP used by KV transfer (default: localhost)",
+    )
+    parser.add_argument(
+        "--prefill-kv-port",
+        type=int,
+        default=14579,
+        help="Prefill KV port (default: 14579)",
+    )
+    parser.add_argument(
+        "--decode-kv-port",
+        type=int,
+        default=14580,
+        help="Decode KV port (default: 14580)",
     )

     return parser.parse_args()
@@ -73,70 +74,129 @@ def main():

     # Initialize configuration using command line parameters
     AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=args.timeout)
-    MAX_CONCURRENT_REQUESTS = args.max_concurrent
-    REQUEST_QUEUE_SIZE = args.queue_size
-    RATE_LIMIT = args.rate_limit
     PREFILL_SERVICE_URL = args.prefill_url
     DECODE_SERVICE_URL = args.decode_url
     PORT = args.port
+    PREFILL_KV_ADDR = f"{args.kv_host}:{args.prefill_kv_port}"
+    DECODE_KV_ADDR = f"{args.kv_host}:{args.decode_kv_port}"
+
+    logger.info(
+        "Proxy resolved KV addresses -> prefill: %s, decode: %s",
+        PREFILL_KV_ADDR,
+        DECODE_KV_ADDR,
+    )
+
     app = Quart(__name__)

-    # Initialize the rate limiter and request queue
-    rate_limiter = RateLimiter(RATE_LIMIT)
-    request_queue = RequestQueue(MAX_CONCURRENT_REQUESTS, REQUEST_QUEUE_SIZE)
-
-    # Attach the configuration object to the application instance
+    # Attach the configuration object to the application instance so helper
+    # coroutines can read the resolved backend URLs and timeouts without using
+    # globals.
     app.config.update(
         {
             "AIOHTTP_TIMEOUT": AIOHTTP_TIMEOUT,
-            "rate_limiter": rate_limiter,
-            "request_queue": request_queue,
             "PREFILL_SERVICE_URL": PREFILL_SERVICE_URL,
             "DECODE_SERVICE_URL": DECODE_SERVICE_URL,
+            "PREFILL_KV_ADDR": PREFILL_KV_ADDR,
+            "DECODE_KV_ADDR": DECODE_KV_ADDR,
         }
     )

-    # Start queue processing on app startup
-    @app.before_serving
-    async def startup():
-        """Start request processing task when app starts serving"""
-        asyncio.create_task(request_queue.process())
+    def _normalize_base_url(url: str) -> str:
+        """Remove any trailing slash so path joins behave predictably."""
+        return url.rstrip("/")

-    async def forward_request(url, data):
-        """Forward request to backend service with rate limiting and error handling"""
-        headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
+    def _get_host_port(url: str) -> str:
+        """Return the hostname:port portion for logging and KV headers."""
+        parsed = urlparse(url)
+        host = parsed.hostname or "localhost"
+        port = parsed.port
+        if port is None:
+            port = 80 if parsed.scheme == "http" else 443
+        return f"{host}:{port}"

-        # Use rate limiter as context manager
-        async with (
-            rate_limiter,
-            aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session,
-        ):
-            try:
-                async with session.post(
-                    url=url, json=data, headers=headers
-                ) as response:
-                    if response.status == 200:
-                        # Stream response chunks
-                        async for chunk_bytes in response.content.iter_chunked(1024):
-                            yield chunk_bytes
-                    else:
-                        # Handle backend service errors
-                        error_text = await response.text()
-                        logger.error(
-                            "Backend service error: %s - %s",
-                            response.status,
-                            error_text,
-                        )
-                        yield b'{"error": "Backend service error"}'
-            except aiohttp.ClientError as e:
-                # Handle connection errors
-                logger.error("Connection error to %s: %s", url, str(e))
-                yield b'{"error": "Service unavailable"}'
-            except asyncio.TimeoutError:
-                # Handle timeout errors
-                logger.error("Timeout connecting to %s", url)
-                yield b'{"error": "Service timeout"}'
+    PREFILL_BASE = _normalize_base_url(PREFILL_SERVICE_URL)
+    DECODE_BASE = _normalize_base_url(DECODE_SERVICE_URL)
+    KV_TARGET = _get_host_port(DECODE_SERVICE_URL)
+
+    def _build_headers(request_id: str) -> dict[str, str]:
+        """Construct the headers expected by vLLM's P2P disagg connector."""
+        headers: dict[str, str] = {"X-Request-Id": request_id, "X-KV-Target": KV_TARGET}
+        api_key = os.environ.get("OPENAI_API_KEY")
+        if api_key:
+            headers["Authorization"] = f"Bearer {api_key}"
+        return headers
+
+    async def _run_prefill(
+        request_path: str,
+        payload: dict,
+        headers: dict[str, str],
+        request_id: str,
+    ):
+        url = f"{PREFILL_BASE}{request_path}"
+        start_ts = time.perf_counter()
+        logger.info("[prefill] start request_id=%s url=%s", request_id, url)
+        try:
+            async with (
+                aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session,
+                session.post(url=url, json=payload, headers=headers) as resp,
+            ):
+                if resp.status != 200:
+                    error_text = await resp.text()
+                    raise RuntimeError(
+                        f"Prefill backend error {resp.status}: {error_text}"
+                    )
+                await resp.read()
+                logger.info(
+                    "[prefill] done request_id=%s status=%s elapsed=%.2fs",
+                    request_id,
+                    resp.status,
+                    time.perf_counter() - start_ts,
+                )
+        except asyncio.TimeoutError as exc:
+            raise RuntimeError(f"Prefill service timeout at {url}") from exc
+        except aiohttp.ClientError as exc:
+            raise RuntimeError(f"Prefill service unavailable at {url}") from exc
+
+    async def _stream_decode(
+        request_path: str,
+        payload: dict,
+        headers: dict[str, str],
+        request_id: str,
+    ):
+        url = f"{DECODE_BASE}{request_path}"
+        # Stream tokens from the decode service once the prefill stage has
+        # materialized KV caches on the target workers.
+        logger.info("[decode] start request_id=%s url=%s", request_id, url)
+        try:
+            async with (
+                aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session,
+                session.post(url=url, json=payload, headers=headers) as resp,
+            ):
+                if resp.status != 200:
+                    error_text = await resp.text()
+                    logger.error(
+                        "Decode backend error %s - %s", resp.status, error_text
+                    )
+                    err_msg = (
+                        '{"error": "Decode backend error ' + str(resp.status) + '"}'
+                    )
+                    yield err_msg.encode()
+                    return
+                logger.info(
+                    "[decode] streaming response request_id=%s status=%s",
+                    request_id,
+                    resp.status,
+                )
+                async for chunk_bytes in resp.content.iter_chunked(1024):
+                    yield chunk_bytes
+                logger.info("[decode] finished streaming request_id=%s", request_id)
+        except asyncio.TimeoutError:
+            logger.error("Decode service timeout at %s", url)
+            yield b'{"error": "Decode service timeout"}'
+        except aiohttp.ClientError as exc:
+            logger.error("Decode service error at %s: %s", url, exc)
+            yield b'{"error": "Decode service unavailable"}'

     async def process_request():
         """Process a single request through prefill and decode stages"""
@@ -146,13 +206,27 @@ def main():

         # Create prefill request (max_tokens=1)
         prefill_request = original_request_data.copy()
         prefill_request["max_tokens"] = 1
+        if "max_completion_tokens" in prefill_request:
+            prefill_request["max_completion_tokens"] = 1

         # Execute prefill stage
-        async for _ in forward_request(PREFILL_SERVICE_URL, prefill_request):
-            continue
+        # The request id encodes both KV socket addresses so the backend can
+        # shuttle tensors directly via NCCL once the prefill response
+        # completes.
+        request_id = (
+            f"___prefill_addr_{PREFILL_KV_ADDR}___decode_addr_"
+            f"{DECODE_KV_ADDR}_{uuid.uuid4().hex}"
+        )
+
+        headers = _build_headers(request_id)
+        await _run_prefill(request.path, prefill_request, headers, request_id)

         # Execute decode stage and stream response
-        generator = forward_request(DECODE_SERVICE_URL, original_request_data)
+        # Pass the unmodified user request so the decode phase can continue
+        # sampling with the already-populated KV cache.
+        generator = _stream_decode(
+            request.path, original_request_data, headers, request_id
+        )
         response = await make_response(generator)
         response.timeout = None  # Disable timeout for streaming response
         return response
@@ -168,23 +242,10 @@ def main():
     @app.route("/v1/completions", methods=["POST"])
     async def handle_request():
-        """Handle incoming API requests with concurrency and rate limiting"""
-        # Create task for request processing
-        task = asyncio.create_task(process_request())
-
-        # Enqueue request or reject if queue is full
-        if not await request_queue.enqueue(task):
-            return Response(
-                response=b'{"error": "Server busy, try again later"}',
-                status=503,
-                content_type="application/json",
-            )
-
+        """Handle an incoming API request via the prefill and decode stages."""
         try:
-            # Return the response from the processing task
-            return await task
+            return await process_request()
         except asyncio.CancelledError:
-            # Handle task cancellation (timeout or queue full)
-            logger.warning("Request cancelled due to timeout or queue full")
+            logger.warning("Request cancelled")
             return Response(
                 response=b'{"error": "Request cancelled"}',
                 status=503,
diff --git a/examples/online_serving/disaggregated_prefill.sh b/examples/online_serving/disaggregated_prefill.sh
index d434e22b1ae88..cd2f2e44a4d69 100644
--- a/examples/online_serving/disaggregated_prefill.sh
+++ b/examples/online_serving/disaggregated_prefill.sh
@@ -24,7 +24,14 @@ cleanup() {
     exit 0
 }

-export VLLM_HOST_IP=$(hostname -I | awk '{print $1}')
+
+if [[ -z "${VLLM_HOST_IP:-}" ]]; then
+    export VLLM_HOST_IP=127.0.0.1
+    echo "Using default VLLM_HOST_IP=127.0.0.1 (override by exporting VLLM_HOST_IP before running this script)"
+else
+    echo "Using provided VLLM_HOST_IP=${VLLM_HOST_IP}"
+fi
+

 # install quart first -- required for disagg prefill proxy serve
 if python3 -c "import quart" &> /dev/null; then
@@ -38,7 +45,7 @@ fi
 wait_for_server() {
   local port=$1
   timeout 1200 bash -c "
-    until curl -s localhost:${port}/v1/completions > /dev/null; do
+    until curl -s localhost:${port}/v1/models > /dev/null; do
      sleep 1
    done" && return 0 || return 1
 }
@@ -48,21 +55,23 @@

 # prefilling instance, which is the KV producer
 CUDA_VISIBLE_DEVICES=0 vllm serve $MODEL_NAME \
+    --host 0.0.0.0 \
     --port 8100 \
     --max-model-len 100 \
     --gpu-memory-utilization 0.8 \
     --trust-remote-code \
     --kv-transfer-config \
-    '{"kv_connector":"P2pNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2}' &
+    '{"kv_connector":"P2pNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2,"kv_buffer_size":"1e9","kv_port":"14579","kv_connector_extra_config":{"proxy_ip":"'"$VLLM_HOST_IP"'","proxy_port":"30001","http_ip":"'"$VLLM_HOST_IP"'","http_port":"8100","send_type":"PUT_ASYNC"}}' &

 # decoding instance, which is the KV consumer
 CUDA_VISIBLE_DEVICES=1 vllm serve $MODEL_NAME \
+    --host 0.0.0.0 \
     --port 8200 \
     --max-model-len 100 \
     --gpu-memory-utilization 0.8 \
     --trust-remote-code \
     --kv-transfer-config \
-    '{"kv_connector":"P2pNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2}' &
+    '{"kv_connector":"P2pNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2,"kv_buffer_size":"1e10","kv_port":"14580","kv_connector_extra_config":{"proxy_ip":"'"$VLLM_HOST_IP"'","proxy_port":"30001","http_ip":"'"$VLLM_HOST_IP"'","http_port":"8200","send_type":"PUT_ASYNC"}}' &

 # wait until prefill and decode instances are ready
 wait_for_server 8100
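For readability, here is the producer-side `--kv-transfer-config` from the script above, expanded as a Python dict. The inline comments are our reading of the fields, not text from the patch; consult vLLM's P2pNcclConnector documentation for authoritative semantics.

```python
# Producer-side kv-transfer-config, pretty-printed. Values are exactly those
# passed in the script; the comments are interpretive.
producer_kv_config = {
    "kv_connector": "P2pNcclConnector",
    "kv_role": "kv_producer",  # the prefill instance produces KV caches
    "kv_rank": 0,
    "kv_parallel_size": 2,
    "kv_buffer_size": "1e9",   # staging buffer; the decode side uses 1e10
    "kv_port": "14579",        # matches the proxy's --prefill-kv-port default
    "kv_connector_extra_config": {
        "proxy_ip": "$VLLM_HOST_IP",  # substituted by the shell at launch
        "proxy_port": "30001",
        "http_ip": "$VLLM_HOST_IP",
        "http_port": "8100",          # the instance's own HTTP port
        "send_type": "PUT_ASYNC",     # push KV blocks asynchronously
    },
}
```

The decode instance differs only in `kv_role` (`kv_consumer`), `kv_rank` (1), `kv_buffer_size` (`1e10`), `kv_port` (`14580`), and `http_port` (`8200`).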
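Once both instances report ready and the proxy is running, a short client makes a quick end-to-end check. This is a sketch, not part of the patch: it assumes the proxy listens on `localhost:8000` (the `--port` default lies outside the hunks shown, so adjust if yours differs) and that the model name matches what both vLLM instances serve.

```python
# Minimal smoke test: send one completion through the proxy, which chains the
# prefill stage (max_tokens forced to 1) and the decode stage internally.
import json
import urllib.request

payload = {
    "model": "MODEL_NAME_HERE",  # placeholder: use the served model's name
    "prompt": "San Francisco is a",
    "max_tokens": 16,
}
req = urllib.request.Request(
    "http://localhost:8000/v1/completions",  # assumed proxy address
    data=json.dumps(payload).encode(),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req, timeout=120) as resp:
    body = json.loads(resp.read())
    print(body["choices"][0]["text"])
```

If the KV handshake is wired correctly, the proxy log should show a `[prefill] done` line followed by a `[decode] finished streaming` line carrying the same `request_id`.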