mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-06 03:07:03 +08:00
parent
7954461d4c
commit
70e06dd574
@ -37,7 +37,7 @@ wait_for_server() {
|
||||
wait_for_disagg_server() {
|
||||
local log_file=$1
|
||||
timeout 1200 bash -c "
|
||||
until grep -q 'zmq Server started at' $log_file; do
|
||||
until grep -q 'PD Worker is ready' $log_file; do
|
||||
sleep 1
|
||||
done" && return 0 || return 1
|
||||
}
|
||||
|
||||
@ -53,9 +53,9 @@ class PDController(EngineClient):
|
||||
|
|
||||
[ zmq ]
|
||||
|
|
||||
[ PDWorker ] [ PDWorker ]
|
||||
|
|
||||
[ Engine ] <---> [ Engine ]
|
||||
[ PDWorker ] [ PDWorker ]
|
||||
| |
|
||||
[ Engine ] <-kv-> [ Engine ]
|
||||
|
||||
After PR #12957, we will support xPyD, so we will
|
||||
also need to implement a scheduler and service
|
||||
@ -80,13 +80,12 @@ class PDController(EngineClient):
|
||||
# TODO: once https://github.com/vllm-project/vllm/pull/12957
|
||||
# lands, do service discovery to scale out workers.
|
||||
self.ctx = zmq.asyncio.Context()
|
||||
self.to_decode = self.ctx.socket(zmq.constants.PUSH)
|
||||
self.to_decode.bind(f"{decode_addr}")
|
||||
self.to_prefill = self.ctx.socket(zmq.constants.PUSH)
|
||||
self.to_prefill.bind(f"{prefill_addr}")
|
||||
self.to_prefill.connect(prefill_addr)
|
||||
self.to_decode = self.ctx.socket(zmq.constants.PUSH)
|
||||
self.to_decode.connect(decode_addr)
|
||||
self.controller_addr = controller_addr
|
||||
self.decode_addr = decode_addr
|
||||
self.prefill_addr = prefill_addr
|
||||
self.ipc_paths = [prefill_addr, decode_addr, controller_addr]
|
||||
|
||||
# Background loops (started on first generate()).
|
||||
self.output_handler: Optional[asyncio.Task] = None
|
||||
@ -123,8 +122,7 @@ class PDController(EngineClient):
|
||||
if (task := self.output_handler) is not None:
|
||||
task.cancel()
|
||||
|
||||
ipc_paths = [self.controller_addr, self.decode_addr, self.prefill_addr]
|
||||
for path in ipc_paths:
|
||||
for path in self.ipc_paths:
|
||||
socket_path = path.replace("ipc://", "")
|
||||
if os.path.exists(socket_path):
|
||||
os.remove(socket_path)
|
||||
@ -178,7 +176,7 @@ class PDController(EngineClient):
|
||||
response = await q.get()
|
||||
if isinstance(response, Exception):
|
||||
raise response
|
||||
logger.debug("Got Decode Response: %s", request.request_id)
|
||||
logger.debug("Prefill Response: %s", request.request_id)
|
||||
|
||||
async def _run_decode(
|
||||
self,
|
||||
@ -196,7 +194,7 @@ class PDController(EngineClient):
|
||||
response = await q.get()
|
||||
if isinstance(response, Exception):
|
||||
raise response
|
||||
logger.debug("Got Decode Response: %s", request.request_id)
|
||||
logger.debug("Decode Response: %s", request.request_id)
|
||||
finished = response.finish_reason is not None
|
||||
yield response
|
||||
|
||||
@ -269,11 +267,9 @@ class PDController(EngineClient):
|
||||
prompt_token_ids=prompt["prompt_token_ids"],
|
||||
sampling_params=sampling_params)
|
||||
request.sampling_params.max_tokens = 1
|
||||
logger.debug("Sending Prefill: %s", request.request_id)
|
||||
pd_response = await self._run_prefill(request, q)
|
||||
|
||||
# (2) Perform the Decodes.
|
||||
logger.debug("Sending Decode: %s", request.request_id)
|
||||
request.sampling_params.max_tokens = original_max_tokens
|
||||
async for pd_response in self._run_decode(request, q):
|
||||
yield self._to_request_output(pd_response,
|
||||
|
||||
@ -8,7 +8,8 @@ import zmq
|
||||
import zmq.asyncio
|
||||
|
||||
from vllm.disaggregated.protocol import (PDAbortRequest, PDGenerationRequest,
|
||||
PDGenerationResponse, PDRequestType)
|
||||
PDGenerationResponse, PDRequestType,
|
||||
PDResponseType)
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.logger import init_logger
|
||||
|
||||
@ -34,13 +35,13 @@ class PDWorker:
|
||||
self.engine = engine
|
||||
|
||||
# ZMQ IPC.
|
||||
self.worker_addr = worker_addr
|
||||
self.controller_addr = controller_addr
|
||||
self.worker_addr = f"ipc://{worker_addr}"
|
||||
self.controller_addr = f"ipc://{controller_addr}"
|
||||
self.ctx = zmq.asyncio.Context()
|
||||
self.from_client = self.ctx.socket(zmq.constants.PULL)
|
||||
self.from_client.connect(f"ipc://{self.worker_addr}")
|
||||
self.to_client = self.ctx.socket(zmq.constants.PUSH)
|
||||
self.to_client.connect(f"ipc://{self.controller_addr}")
|
||||
self.from_controller = self.ctx.socket(zmq.constants.PULL)
|
||||
self.from_controller.bind(self.worker_addr)
|
||||
self.to_controller = self.ctx.socket(zmq.constants.PUSH)
|
||||
self.to_controller.connect(self.controller_addr)
|
||||
self.decode_generation = msgspec.msgpack.Decoder(PDGenerationRequest)
|
||||
self.decode_abort = msgspec.msgpack.Decoder(PDAbortRequest)
|
||||
self.encoder = msgspec.msgpack.Encoder()
|
||||
@ -71,9 +72,12 @@ class PDWorker:
|
||||
"""
|
||||
logger.info("PDWorker is ready To handle requests.")
|
||||
|
||||
poller = zmq.asyncio.Poller()
|
||||
poller.register(self.from_controller, zmq.POLLIN)
|
||||
|
||||
while True:
|
||||
# 1) Get request from the PDController.
|
||||
req_type, req_data = await self.from_client.recv_multipart()
|
||||
req_type, req_data = await self.from_controller.recv_multipart()
|
||||
|
||||
# 2) Handle the request.
|
||||
await self._handle_request(req_type, req_data)
|
||||
@ -85,11 +89,13 @@ class PDWorker:
|
||||
2) call the appropriate handler for the request type
|
||||
"""
|
||||
if req_type == PDRequestType.GENERATION:
|
||||
req = self.decode_generation(req_data)
|
||||
req = self.decode_generation.decode(req_data)
|
||||
await self._generation_handler(req)
|
||||
elif req_type == PDRequestType.ABORT:
|
||||
req = self.decode_abort(req_data)
|
||||
req = self.decode_abort.decode(req_data)
|
||||
await self._abort_handler(req)
|
||||
else:
|
||||
raise Exception(f"Unknown Request Type: {req_type}.")
|
||||
|
||||
async def _generation_handler(self, req: PDGenerationRequest):
|
||||
"""
|
||||
@ -133,7 +139,5 @@ class PDWorker:
|
||||
|
||||
# 4) Serialize and send to PDController.
|
||||
response_bytes = self.encoder.encode(response)
|
||||
msg = [PDGenerationResponse.SUCCE, response_bytes]
|
||||
logger.debug("Sending: %s", request_id)
|
||||
await self.to_client.send_multipart(msg, copy=False)
|
||||
logger.debug("Sent: %s", request_id)
|
||||
msg = (PDResponseType.GENERATION, response_bytes)
|
||||
await self.to_controller.send_multipart(msg, copy=False)
|
||||
|
||||
@ -1,9 +1,10 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""
|
||||
Toy connector for prototyping.
|
||||
Toy API Server for prototyping.
|
||||
|
||||
When PDController supports the protocol and we clean up the
|
||||
OpenAI Server, we can drop this in favor of vllm serve.
|
||||
Once the PDController is more mature and we clean up
|
||||
the OpenAI Server a bit, we can put the PDController
|
||||
directly inside and launch with vllm serve.
|
||||
"""
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
|
||||
@ -10,12 +10,14 @@ from vllm.logger import init_logger
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
from vllm.version import __version__ as VLLM_VERSION
|
||||
|
||||
logger = init_logger(__name__)
|
||||
logger = init_logger("vllm.entrypoints.disaggregated.worker")
|
||||
|
||||
|
||||
async def run(args, engine: EngineClient):
|
||||
try:
|
||||
worker = PDWorker(engine, args.worker_addr, args.controller_addr)
|
||||
worker = PDWorker(engine=engine,
|
||||
worker_addr=args.worker_addr,
|
||||
controller_addr=args.controller_addr)
|
||||
await worker.run_busy_loop()
|
||||
finally:
|
||||
worker.shutdown()
|
||||
@ -25,7 +27,6 @@ async def main(args) -> None:
|
||||
logger.info("vLLM P/D Worker Server %s", VLLM_VERSION)
|
||||
logger.info("Args: %s", args)
|
||||
|
||||
args.disable_frontend_multiprocessing = False
|
||||
async with build_async_engine_client(args) as engine:
|
||||
await run(args, engine)
|
||||
|
||||
@ -40,5 +41,8 @@ if __name__ == "__main__":
|
||||
type=str,
|
||||
required=True,
|
||||
help='The address of the worker.')
|
||||
parser.add_argument('--disable-frontend-multiprocessing',
|
||||
action="store_true",
|
||||
help='Disable MQLLMEngine for AsyncLLMEngine.')
|
||||
AsyncEngineArgs.add_cli_args(parser)
|
||||
uvloop.run(main(parser.parse_args()))
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user