mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-19 06:37:00 +08:00
parent
28d0396ff1
commit
66349c33a1
@ -3,7 +3,7 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import os
|
import os
|
||||||
from collections.abc import AsyncGenerator, Mapping
|
from collections.abc import AsyncGenerator, Mapping
|
||||||
from typing import Optional
|
from typing import Optional, Union
|
||||||
|
|
||||||
import msgspec
|
import msgspec
|
||||||
import zmq
|
import zmq
|
||||||
@ -11,7 +11,10 @@ import zmq.asyncio
|
|||||||
|
|
||||||
from vllm.config import DecodingConfig, ModelConfig
|
from vllm.config import DecodingConfig, ModelConfig
|
||||||
from vllm.core.scheduler import SchedulerOutputs
|
from vllm.core.scheduler import SchedulerOutputs
|
||||||
from vllm.entrypoints.disaggregated.types import PDRequest, PDResponse
|
from vllm.disaggregated.protocol import (PDGenerationRequest,
|
||||||
|
PDGenerationResponse, PDRequestType,
|
||||||
|
PDResponseType)
|
||||||
|
from vllm.engine.protocol import EngineClient
|
||||||
from vllm.inputs.data import PromptType
|
from vllm.inputs.data import PromptType
|
||||||
from vllm.inputs.preprocess import InputPreprocessor
|
from vllm.inputs.preprocess import InputPreprocessor
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
@ -30,14 +33,36 @@ logger = init_logger(__name__)
|
|||||||
DEFAULT_MAX_TOKENS = 32000
|
DEFAULT_MAX_TOKENS = 32000
|
||||||
|
|
||||||
|
|
||||||
class PDClient:
|
class PDController(EngineClient):
|
||||||
"""
|
"""
|
||||||
PDEngine:
|
Controller that schedules work on the PDWorkers.
|
||||||
Equivalent of AsyncLLM for P/D. Assumes there is
|
|
||||||
a Prefill and Decode service already running.
|
|
||||||
|
|
||||||
* TODO: actually handle errors and failure.
|
Conforms to the EngineClient protocol so it can
|
||||||
* TODO: support more than just text input.
|
be wrapped with the OpenAI Server.
|
||||||
|
|
||||||
|
Two Phases:
|
||||||
|
* Send request to prefill worker, await ack.
|
||||||
|
* Send request to decode worker.
|
||||||
|
|
||||||
|
KVSync happens directly between Engines,
|
||||||
|
handled by vLLM KVCacheTransfer.
|
||||||
|
|
||||||
|
[ OpenAI Server ]
|
||||||
|
|
|
||||||
|
[ PDController ]
|
||||||
|
|
|
||||||
|
[ zmq ]
|
||||||
|
|
|
||||||
|
[ PDWorker ] [ PDWorker ]
|
||||||
|
|
|
||||||
|
[ Engine ] <---> [ Engine ]
|
||||||
|
|
||||||
|
After PR #12957, we will support xPyD, so we will
|
||||||
|
also need to implement a scheduler.
|
||||||
|
we will need to support multiple
|
||||||
|
|
||||||
|
* TODO: actually handle errors and failure.
|
||||||
|
* TODO: support the full API (logprobs, multimodal).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, prefill_addr: str, decode_addr: str,
|
def __init__(self, prefill_addr: str, decode_addr: str,
|
||||||
@ -49,6 +74,8 @@ class PDClient:
|
|||||||
self.encoder = msgspec.msgpack.Encoder()
|
self.encoder = msgspec.msgpack.Encoder()
|
||||||
|
|
||||||
# ZMQ communication.
|
# ZMQ communication.
|
||||||
|
# TODO: once https://github.com/vllm-project/vllm/pull/12957
|
||||||
|
# lands, do service discovery to scale out workers.
|
||||||
self.ctx = zmq.asyncio.Context()
|
self.ctx = zmq.asyncio.Context()
|
||||||
self.to_decode = self.ctx.socket(zmq.constants.PUSH)
|
self.to_decode = self.ctx.socket(zmq.constants.PUSH)
|
||||||
self.to_decode.bind(f"{decode_addr}")
|
self.to_decode.bind(f"{decode_addr}")
|
||||||
@ -63,7 +90,7 @@ class PDClient:
|
|||||||
self.log_running: Optional[asyncio.Task] = None
|
self.log_running: Optional[asyncio.Task] = None
|
||||||
|
|
||||||
# Dummy: needed for EngineClient Protocol.
|
# Dummy: needed for EngineClient Protocol.
|
||||||
# TODO: refactor EngineClient to avoid needing this.
|
# TODO: refactor OAI Server to avoid needing this.
|
||||||
self.model_config = ModelConfig(model=model_name,
|
self.model_config = ModelConfig(model=model_name,
|
||||||
tokenizer=model_name,
|
tokenizer=model_name,
|
||||||
tokenizer_mode="auto",
|
tokenizer_mode="auto",
|
||||||
@ -73,7 +100,7 @@ class PDClient:
|
|||||||
seed=42)
|
seed=42)
|
||||||
|
|
||||||
# Dummy: needed for EngineClient Protocol.
|
# Dummy: needed for EngineClient Protocol.
|
||||||
# TODO: refactor EngineClient to avoid needing this.
|
# TODO: refactor OAI Server to avoid needing this.
|
||||||
init_kwargs = dict(
|
init_kwargs = dict(
|
||||||
tokenizer_id=self.model_config.tokenizer,
|
tokenizer_id=self.model_config.tokenizer,
|
||||||
enable_lora=False,
|
enable_lora=False,
|
||||||
@ -109,7 +136,7 @@ class PDClient:
|
|||||||
Pull responses from Decode + Prefill engines and
|
Pull responses from Decode + Prefill engines and
|
||||||
distribute back to the generate() tasks.
|
distribute back to the generate() tasks.
|
||||||
"""
|
"""
|
||||||
decoder = msgspec.msgpack.Decoder(PDResponse)
|
decoder = msgspec.msgpack.Decoder(PDGenerationResponse)
|
||||||
|
|
||||||
socket: Optional[zmq.asyncio.Socket] = None
|
socket: Optional[zmq.asyncio.Socket] = None
|
||||||
try:
|
try:
|
||||||
@ -117,72 +144,78 @@ class PDClient:
|
|||||||
socket.bind(self.connector_addr)
|
socket.bind(self.connector_addr)
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
reponse_bytes = await socket.recv()
|
res_type, res_data = await socket.recv_multipart()
|
||||||
response = decoder.decode(reponse_bytes)
|
if res_type == PDResponseType.FAILURE:
|
||||||
logger.debug("Got Response: %s", response.request_id)
|
raise Exception("Failure Response from PDWorker.")
|
||||||
self.queues[response.request_id].put_nowait(response)
|
elif res_type == PDResponseType.GENERATION:
|
||||||
except:
|
response = decoder.decode(res_data)
|
||||||
# TODO: actually handle failure and shutdown.
|
logger.debug("Got Response: %s", response.request_id)
|
||||||
raise
|
self.queues[response.request_id].put_nowait(response)
|
||||||
|
else:
|
||||||
|
raise Exception("Unknown response type.")
|
||||||
|
except Exception as e:
|
||||||
|
# TODO: distinguish between fatal and non-fatal errors.
|
||||||
|
for _, q in self.queues.values():
|
||||||
|
q.put_nowait(e)
|
||||||
|
raise e
|
||||||
finally:
|
finally:
|
||||||
if socket is not None:
|
if socket is not None:
|
||||||
socket.close(linger=0)
|
socket.close(linger=0)
|
||||||
|
|
||||||
async def _prefill(
|
async def _run_prefill(
|
||||||
self,
|
self,
|
||||||
request: PDRequest,
|
request: PDGenerationRequest,
|
||||||
q: asyncio.Queue[PDResponse],
|
q: asyncio.Queue[Union[Exception, PDGenerationResponse]],
|
||||||
):
|
):
|
||||||
# Send request to the prefill instance.
|
# Send request to the prefill instance.
|
||||||
req_bytes = self.encoder.encode(request)
|
req_bytes = self.encoder.encode(request)
|
||||||
await self.to_prefill.send(req_bytes, copy=False)
|
msg = (PDRequestType.GENERATION, req_bytes)
|
||||||
|
await self.to_prefill.send_multipart(msg, copy=False)
|
||||||
|
|
||||||
# Wait for the prefill to be done.
|
# Wait for the prefill to be done.
|
||||||
response = await q.get()
|
response = await q.get()
|
||||||
assert response.request_id == request.request_id
|
if isinstance(response, Exception):
|
||||||
if not response.success:
|
raise response
|
||||||
# TODO: actual error handling and shutdown.
|
|
||||||
raise Exception("Failed Prefill Request.")
|
|
||||||
|
|
||||||
async def _decode(
|
async def _run_decode(
|
||||||
self,
|
self,
|
||||||
request: PDRequest,
|
request: PDGenerationRequest,
|
||||||
q: asyncio.Queue[PDResponse],
|
q: asyncio.Queue[Union[Exception, PDGenerationResponse]],
|
||||||
) -> AsyncGenerator[PDRequest]:
|
) -> AsyncGenerator[PDGenerationResponse]:
|
||||||
|
|
||||||
# Send request to the decode instance.
|
# Send request to the decode instance.
|
||||||
req_bytes = self.encoder.encode(request)
|
req_bytes = self.encoder.encode(request)
|
||||||
await self.to_decode.send(req_bytes, copy=False)
|
msg = (PDRequestType.GENERATION, req_bytes)
|
||||||
|
await self.to_decode.send_multipart(msg, copy=False)
|
||||||
|
|
||||||
# Iterate response queue and yield each response to caller..
|
# Iterate response queue and yield each response to caller.
|
||||||
finished = False
|
finished = False
|
||||||
while not finished:
|
while not finished:
|
||||||
response = await q.get()
|
response = await q.get()
|
||||||
if not response.success:
|
if isinstance(response, Exception):
|
||||||
# TODO: actual error handling and shutdown.
|
raise response
|
||||||
raise Exception("Failed Decode Request.")
|
|
||||||
finished = response.finish_reason is not None
|
finished = response.finish_reason is not None
|
||||||
yield response
|
yield response
|
||||||
|
|
||||||
def _to_request_output(
|
def _to_request_output(
|
||||||
self,
|
self,
|
||||||
pd_response: PDResponse,
|
response: PDGenerationResponse,
|
||||||
prompt_token_ids: list[int],
|
prompt_token_ids: list[int],
|
||||||
) -> RequestOutput:
|
) -> RequestOutput:
|
||||||
finished = pd_response.finish_reason is not None
|
finished = response.finish_reason is not None
|
||||||
return RequestOutput(
|
return RequestOutput(
|
||||||
request_id=pd_response.request_id,
|
request_id=response.request_id,
|
||||||
prompt=None,
|
prompt=None,
|
||||||
prompt_token_ids=prompt_token_ids,
|
prompt_token_ids=prompt_token_ids,
|
||||||
prompt_logprobs=None,
|
prompt_logprobs=None,
|
||||||
outputs=[
|
outputs=[
|
||||||
CompletionOutput(index=0,
|
CompletionOutput(index=0,
|
||||||
text=pd_response.text,
|
text=response.text,
|
||||||
token_ids=pd_response.token_ids,
|
token_ids=response.token_ids,
|
||||||
cumulative_logprob=None,
|
cumulative_logprob=None,
|
||||||
logprobs=None,
|
logprobs=None,
|
||||||
finish_reason=pd_response.finish_reason,
|
finish_reason=response.finish_reason,
|
||||||
stop_reason=pd_response.stop_reason)
|
stop_reason=response.stop_reason)
|
||||||
],
|
],
|
||||||
finished=finished,
|
finished=finished,
|
||||||
)
|
)
|
||||||
@ -223,24 +256,24 @@ class PDClient:
|
|||||||
raise ValueError(f"Found duplicate request_id: {request_id}!")
|
raise ValueError(f"Found duplicate request_id: {request_id}!")
|
||||||
|
|
||||||
# Queue to gather output from output_handler.
|
# Queue to gather output from output_handler.
|
||||||
q: asyncio.Queue[PDResponse] = asyncio.Queue()
|
q = asyncio.Queue()
|
||||||
self.queues[request_id] = q
|
self.queues[request_id] = q
|
||||||
|
|
||||||
# (1) Perform the Prefill.
|
# (1) Perform the Prefill.
|
||||||
original_max_tokens = sampling_params.max_tokens
|
original_max_tokens = sampling_params.max_tokens
|
||||||
prompt_token_ids = prompt["prompt_token_ids"]
|
prompt_token_ids = prompt["prompt_token_ids"]
|
||||||
request = PDRequest(request_id=request_id,
|
request = PDGenerationRequest(request_id=request_id,
|
||||||
prompt_token_ids=prompt_token_ids,
|
prompt_token_ids=prompt_token_ids,
|
||||||
sampling_params=sampling_params)
|
sampling_params=sampling_params)
|
||||||
request.sampling_params.max_tokens = 1
|
request.sampling_params.max_tokens = 1
|
||||||
logger.debug("Sending Prefill: %s", request.request_id)
|
logger.debug("Sending Prefill: %s", request.request_id)
|
||||||
pd_response = await self._prefill(request, q)
|
pd_response = await self._run_prefill(request, q)
|
||||||
|
|
||||||
# (2) Perform the Decodes.
|
# (2) Perform the Decodes.
|
||||||
logger.debug("Sending Decode: %s", request.request_id)
|
logger.debug("Sending Decode: %s", request.request_id)
|
||||||
request.sampling_params.max_tokens = original_max_tokens
|
request.sampling_params.max_tokens = original_max_tokens
|
||||||
async for pd_response in self._decode(request, q):
|
async for pd_response in self._run_decode(request, q):
|
||||||
logger.debug("Got Decode: %s", request.request_id)
|
logger.debug("Got Response: %s", request.request_id)
|
||||||
yield self._to_request_output(pd_response, prompt_token_ids)
|
yield self._to_request_output(pd_response, prompt_token_ids)
|
||||||
|
|
||||||
async def beam_search(
|
async def beam_search(
|
||||||
|
|||||||
@ -65,9 +65,9 @@ class PDWorker:
|
|||||||
|
|
||||||
async def run_busy_loop(self):
|
async def run_busy_loop(self):
|
||||||
"""
|
"""
|
||||||
Main execution loop for the PDWorker:
|
main loop:
|
||||||
* 1) Wait for a request from the PDClient.
|
1) wait for a request from the PDClient
|
||||||
* 2) Handle the request.
|
2) handle the request
|
||||||
"""
|
"""
|
||||||
logger.info("PDWorker is ready To handle requests.")
|
logger.info("PDWorker is ready To handle requests.")
|
||||||
|
|
||||||
@ -76,10 +76,14 @@ class PDWorker:
|
|||||||
req_type, req_data = await self.from_client.recv_multipart()
|
req_type, req_data = await self.from_client.recv_multipart()
|
||||||
|
|
||||||
# 2) Handle the request.
|
# 2) Handle the request.
|
||||||
await self._handle_request(req_type, req_data.buffer)
|
await self._handle_request(req_type, req_data)
|
||||||
|
|
||||||
async def _handle_request(self, req_type: bytes, req_data: bytes):
|
async def _handle_request(self, req_type: bytes, req_data: bytes):
|
||||||
"""Parse the request type and call the appropriate handler."""
|
"""
|
||||||
|
request handler:
|
||||||
|
1) parse the request type
|
||||||
|
2) call the appropriate handler for the request type
|
||||||
|
"""
|
||||||
if req_type == PDRequestType.GENERATION:
|
if req_type == PDRequestType.GENERATION:
|
||||||
req = self.decode_generation(req_data)
|
req = self.decode_generation(req_data)
|
||||||
await self._generation_handler(req)
|
await self._generation_handler(req)
|
||||||
@ -88,13 +92,18 @@ class PDWorker:
|
|||||||
await self._abort_handler(req)
|
await self._abort_handler(req)
|
||||||
|
|
||||||
async def _generation_handler(self, req: PDGenerationRequest):
|
async def _generation_handler(self, req: PDGenerationRequest):
|
||||||
"""Launch generation in a background task."""
|
"""
|
||||||
|
Handle a PDGenerationRequest by launching a task.
|
||||||
|
"""
|
||||||
task = asyncio.create_task(self._generate(req))
|
task = asyncio.create_task(self._generate(req))
|
||||||
self.running_requests.add(task)
|
self.running_requests.add(task)
|
||||||
task.add_done_callback(self.running_requests.discard)
|
task.add_done_callback(self.running_requests.discard)
|
||||||
|
|
||||||
async def _abort_handler(self, req: PDGenerationRequest):
|
async def _abort_handler(self, req: PDGenerationRequest):
|
||||||
"""Abort the request in the engine."""
|
"""
|
||||||
|
Handle a PDAbortRequest by cancelling the running task.
|
||||||
|
The _generate coro aborts in the Engine.
|
||||||
|
"""
|
||||||
# Convert running_requests set() into a dict(), keyed
|
# Convert running_requests set() into a dict(), keyed
|
||||||
# by request_id. Cancel the task when an abort comes in.
|
# by request_id. Cancel the task when an abort comes in.
|
||||||
# Then update the _generate coroutine to handle a
|
# Then update the _generate coroutine to handle a
|
||||||
@ -103,11 +112,11 @@ class PDWorker:
|
|||||||
|
|
||||||
async def _generate(self, req: PDGenerationRequest):
|
async def _generate(self, req: PDGenerationRequest):
|
||||||
"""
|
"""
|
||||||
Handle a single PDRequest:
|
Handle a single PDGenerationRequest:
|
||||||
* 1) submit request to AsyncLLM
|
* 1) submit request to AsyncLLM
|
||||||
* 2) iterate the RequestOutputs
|
* 2) iterate the RequestOutputs
|
||||||
* 3) convert RequestOutput --> PDResponse
|
* 3) convert RequestOutput --> PDGenerationResponse
|
||||||
* 4) serialize and send to Connector.
|
* 4) serialize and send to PDClient
|
||||||
"""
|
"""
|
||||||
request_id = req.request_id
|
request_id = req.request_id
|
||||||
|
|
||||||
|
|||||||
@ -55,9 +55,3 @@ class PDGenerationResponse(msgspec.Struct):
|
|||||||
finish_reason=out.finish_reason,
|
finish_reason=out.finish_reason,
|
||||||
stop_reason=out.stop_reason,
|
stop_reason=out.stop_reason,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class PDGenerationFailure(msgspec.Struct):
|
|
||||||
request_id: str
|
|
||||||
error_message: str
|
|
||||||
engine_dead: bool
|
|
||||||
|
|||||||
@ -1,5 +1,10 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
"""Toy Server For Prototyping."""
|
"""
|
||||||
|
Toy connector for prototyping.
|
||||||
|
|
||||||
|
When PDClient supports the protocol and we clean up the
|
||||||
|
OpenAI Server, we can drop this in favor of vllm serve.
|
||||||
|
"""
|
||||||
|
|
||||||
from collections.abc import AsyncIterator
|
from collections.abc import AsyncIterator
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
@ -9,7 +14,7 @@ import uvloop
|
|||||||
from fastapi import FastAPI, Request
|
from fastapi import FastAPI, Request
|
||||||
from fastapi.responses import JSONResponse, StreamingResponse
|
from fastapi.responses import JSONResponse, StreamingResponse
|
||||||
|
|
||||||
from vllm.entrypoints.disaggregated.engine import PDEngine
|
from vllm.disaggregated.pd_client import PDClient
|
||||||
from vllm.entrypoints.openai.protocol import (CompletionRequest,
|
from vllm.entrypoints.openai.protocol import (CompletionRequest,
|
||||||
CompletionResponse,
|
CompletionResponse,
|
||||||
ErrorResponse)
|
ErrorResponse)
|
||||||
@ -46,34 +51,31 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
|
|||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def pd_engine_client_ctx_manager(
|
async def pd_client_ctx(prefill_addr: str, decode_addr: str,
|
||||||
prefill_addr: str, decode_addr: str, connector_addr: str,
|
connector_addr: str,
|
||||||
model_name: str) -> AsyncIterator[PDEngine]:
|
model_name: str) -> AsyncIterator[PDClient]:
|
||||||
engine = PDEngine(prefill_addr, decode_addr, connector_addr, model_name)
|
client = PDClient(prefill_addr, decode_addr, connector_addr, model_name)
|
||||||
yield engine
|
yield client
|
||||||
engine.shutdown()
|
client.shutdown()
|
||||||
|
|
||||||
|
|
||||||
async def main(args, **uvicorn_kwargs):
|
async def main(args, **uvicorn_kwargs):
|
||||||
logger.info("vLLM Disaggregate Connector Start %s %s", args,
|
logger.info("vLLM Disaggregated Connector Start %s %s", args,
|
||||||
uvicorn_kwargs)
|
uvicorn_kwargs)
|
||||||
|
|
||||||
# Avoid dropping requests under high concurrency.
|
# Avoid dropping requests under high concurrency.
|
||||||
set_ulimit()
|
set_ulimit()
|
||||||
|
|
||||||
# IPC Paths.
|
# IPC Paths.
|
||||||
# NOTE FOR DEVELOPERS: when shifting to TCP, ensure you
|
|
||||||
# are not using pickle to avoid RCE security flaw.
|
|
||||||
prefill_addr = f"ipc://{args.prefill_addr}"
|
prefill_addr = f"ipc://{args.prefill_addr}"
|
||||||
decode_addr = f"ipc://{args.decode_addr}"
|
decode_addr = f"ipc://{args.decode_addr}"
|
||||||
connector_addr = f"ipc://{args.connector_addr}"
|
connector_addr = f"ipc://{args.connector_addr}"
|
||||||
|
|
||||||
# Start Engine.
|
# Start Engine.
|
||||||
async with pd_engine_client_ctx_manager(
|
async with pd_client_ctx(prefill_addr=prefill_addr,
|
||||||
prefill_addr=prefill_addr,
|
decode_addr=decode_addr,
|
||||||
decode_addr=decode_addr,
|
connector_addr=connector_addr,
|
||||||
connector_addr=connector_addr,
|
model_name=args.model) as engine_client:
|
||||||
model_name=args.model) as engine_client:
|
|
||||||
|
|
||||||
# Initialize App State.
|
# Initialize App State.
|
||||||
model_config = await engine_client.get_model_config()
|
model_config = await engine_client.get_model_config()
|
||||||
@ -93,7 +95,7 @@ async def main(args, **uvicorn_kwargs):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Run Server.
|
# Run Server.
|
||||||
config = uvicorn.Config(app, host="0.0.0.0", port=args.port)
|
config = uvicorn.Config(app, host=args.host, port=args.port)
|
||||||
server = uvicorn.Server(config)
|
server = uvicorn.Server(config)
|
||||||
await server.serve()
|
await server.serve()
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user