mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-03 22:09:06 +08:00
parent
7954461d4c
commit
70e06dd574
@ -37,7 +37,7 @@ wait_for_server() {
|
|||||||
wait_for_disagg_server() {
|
wait_for_disagg_server() {
|
||||||
local log_file=$1
|
local log_file=$1
|
||||||
timeout 1200 bash -c "
|
timeout 1200 bash -c "
|
||||||
until grep -q 'zmq Server started at' $log_file; do
|
until grep -q 'PD Worker is ready' $log_file; do
|
||||||
sleep 1
|
sleep 1
|
||||||
done" && return 0 || return 1
|
done" && return 0 || return 1
|
||||||
}
|
}
|
||||||
|
|||||||
@ -53,9 +53,9 @@ class PDController(EngineClient):
|
|||||||
|
|
|
|
||||||
[ zmq ]
|
[ zmq ]
|
||||||
|
|
|
|
||||||
[ PDWorker ] [ PDWorker ]
|
[ PDWorker ] [ PDWorker ]
|
||||||
|
|
| |
|
||||||
[ Engine ] <---> [ Engine ]
|
[ Engine ] <-kv-> [ Engine ]
|
||||||
|
|
||||||
After PR #12957, we will support xPyD, so we will
|
After PR #12957, we will support xPyD, so we will
|
||||||
also need to implement a scheduler and service
|
also need to implement a scheduler and service
|
||||||
@ -80,13 +80,12 @@ class PDController(EngineClient):
|
|||||||
# TODO: once https://github.com/vllm-project/vllm/pull/12957
|
# TODO: once https://github.com/vllm-project/vllm/pull/12957
|
||||||
# lands, do service discovery to scale out workers.
|
# lands, do service discovery to scale out workers.
|
||||||
self.ctx = zmq.asyncio.Context()
|
self.ctx = zmq.asyncio.Context()
|
||||||
self.to_decode = self.ctx.socket(zmq.constants.PUSH)
|
|
||||||
self.to_decode.bind(f"{decode_addr}")
|
|
||||||
self.to_prefill = self.ctx.socket(zmq.constants.PUSH)
|
self.to_prefill = self.ctx.socket(zmq.constants.PUSH)
|
||||||
self.to_prefill.bind(f"{prefill_addr}")
|
self.to_prefill.connect(prefill_addr)
|
||||||
|
self.to_decode = self.ctx.socket(zmq.constants.PUSH)
|
||||||
|
self.to_decode.connect(decode_addr)
|
||||||
self.controller_addr = controller_addr
|
self.controller_addr = controller_addr
|
||||||
self.decode_addr = decode_addr
|
self.ipc_paths = [prefill_addr, decode_addr, controller_addr]
|
||||||
self.prefill_addr = prefill_addr
|
|
||||||
|
|
||||||
# Background loops (started on first generate()).
|
# Background loops (started on first generate()).
|
||||||
self.output_handler: Optional[asyncio.Task] = None
|
self.output_handler: Optional[asyncio.Task] = None
|
||||||
@ -123,8 +122,7 @@ class PDController(EngineClient):
|
|||||||
if (task := self.output_handler) is not None:
|
if (task := self.output_handler) is not None:
|
||||||
task.cancel()
|
task.cancel()
|
||||||
|
|
||||||
ipc_paths = [self.controller_addr, self.decode_addr, self.prefill_addr]
|
for path in self.ipc_paths:
|
||||||
for path in ipc_paths:
|
|
||||||
socket_path = path.replace("ipc://", "")
|
socket_path = path.replace("ipc://", "")
|
||||||
if os.path.exists(socket_path):
|
if os.path.exists(socket_path):
|
||||||
os.remove(socket_path)
|
os.remove(socket_path)
|
||||||
@ -178,7 +176,7 @@ class PDController(EngineClient):
|
|||||||
response = await q.get()
|
response = await q.get()
|
||||||
if isinstance(response, Exception):
|
if isinstance(response, Exception):
|
||||||
raise response
|
raise response
|
||||||
logger.debug("Got Decode Response: %s", request.request_id)
|
logger.debug("Prefill Response: %s", request.request_id)
|
||||||
|
|
||||||
async def _run_decode(
|
async def _run_decode(
|
||||||
self,
|
self,
|
||||||
@ -196,7 +194,7 @@ class PDController(EngineClient):
|
|||||||
response = await q.get()
|
response = await q.get()
|
||||||
if isinstance(response, Exception):
|
if isinstance(response, Exception):
|
||||||
raise response
|
raise response
|
||||||
logger.debug("Got Decode Response: %s", request.request_id)
|
logger.debug("Decode Response: %s", request.request_id)
|
||||||
finished = response.finish_reason is not None
|
finished = response.finish_reason is not None
|
||||||
yield response
|
yield response
|
||||||
|
|
||||||
@ -269,11 +267,9 @@ class PDController(EngineClient):
|
|||||||
prompt_token_ids=prompt["prompt_token_ids"],
|
prompt_token_ids=prompt["prompt_token_ids"],
|
||||||
sampling_params=sampling_params)
|
sampling_params=sampling_params)
|
||||||
request.sampling_params.max_tokens = 1
|
request.sampling_params.max_tokens = 1
|
||||||
logger.debug("Sending Prefill: %s", request.request_id)
|
|
||||||
pd_response = await self._run_prefill(request, q)
|
pd_response = await self._run_prefill(request, q)
|
||||||
|
|
||||||
# (2) Perform the Decodes.
|
# (2) Perform the Decodes.
|
||||||
logger.debug("Sending Decode: %s", request.request_id)
|
|
||||||
request.sampling_params.max_tokens = original_max_tokens
|
request.sampling_params.max_tokens = original_max_tokens
|
||||||
async for pd_response in self._run_decode(request, q):
|
async for pd_response in self._run_decode(request, q):
|
||||||
yield self._to_request_output(pd_response,
|
yield self._to_request_output(pd_response,
|
||||||
|
|||||||
@ -8,7 +8,8 @@ import zmq
|
|||||||
import zmq.asyncio
|
import zmq.asyncio
|
||||||
|
|
||||||
from vllm.disaggregated.protocol import (PDAbortRequest, PDGenerationRequest,
|
from vllm.disaggregated.protocol import (PDAbortRequest, PDGenerationRequest,
|
||||||
PDGenerationResponse, PDRequestType)
|
PDGenerationResponse, PDRequestType,
|
||||||
|
PDResponseType)
|
||||||
from vllm.engine.protocol import EngineClient
|
from vllm.engine.protocol import EngineClient
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
|
||||||
@ -34,13 +35,13 @@ class PDWorker:
|
|||||||
self.engine = engine
|
self.engine = engine
|
||||||
|
|
||||||
# ZMQ IPC.
|
# ZMQ IPC.
|
||||||
self.worker_addr = worker_addr
|
self.worker_addr = f"ipc://{worker_addr}"
|
||||||
self.controller_addr = controller_addr
|
self.controller_addr = f"ipc://{controller_addr}"
|
||||||
self.ctx = zmq.asyncio.Context()
|
self.ctx = zmq.asyncio.Context()
|
||||||
self.from_client = self.ctx.socket(zmq.constants.PULL)
|
self.from_controller = self.ctx.socket(zmq.constants.PULL)
|
||||||
self.from_client.connect(f"ipc://{self.worker_addr}")
|
self.from_controller.bind(self.worker_addr)
|
||||||
self.to_client = self.ctx.socket(zmq.constants.PUSH)
|
self.to_controller = self.ctx.socket(zmq.constants.PUSH)
|
||||||
self.to_client.connect(f"ipc://{self.controller_addr}")
|
self.to_controller.connect(self.controller_addr)
|
||||||
self.decode_generation = msgspec.msgpack.Decoder(PDGenerationRequest)
|
self.decode_generation = msgspec.msgpack.Decoder(PDGenerationRequest)
|
||||||
self.decode_abort = msgspec.msgpack.Decoder(PDAbortRequest)
|
self.decode_abort = msgspec.msgpack.Decoder(PDAbortRequest)
|
||||||
self.encoder = msgspec.msgpack.Encoder()
|
self.encoder = msgspec.msgpack.Encoder()
|
||||||
@ -71,9 +72,12 @@ class PDWorker:
|
|||||||
"""
|
"""
|
||||||
logger.info("PDWorker is ready To handle requests.")
|
logger.info("PDWorker is ready To handle requests.")
|
||||||
|
|
||||||
|
poller = zmq.asyncio.Poller()
|
||||||
|
poller.register(self.from_controller, zmq.POLLIN)
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
# 1) Get request from the Connector.
|
# 1) Get request from the Connector.
|
||||||
req_type, req_data = await self.from_client.recv_multipart()
|
req_type, req_data = await self.from_controller.recv_multipart()
|
||||||
|
|
||||||
# 2) Handle the request.
|
# 2) Handle the request.
|
||||||
await self._handle_request(req_type, req_data)
|
await self._handle_request(req_type, req_data)
|
||||||
@ -85,11 +89,13 @@ class PDWorker:
|
|||||||
2) call the appropriate handler for the request type
|
2) call the appropriate handler for the request type
|
||||||
"""
|
"""
|
||||||
if req_type == PDRequestType.GENERATION:
|
if req_type == PDRequestType.GENERATION:
|
||||||
req = self.decode_generation(req_data)
|
req = self.decode_generation.decode(req_data)
|
||||||
await self._generation_handler(req)
|
await self._generation_handler(req)
|
||||||
elif req_type == PDRequestType.ABORT:
|
elif req_type == PDRequestType.ABORT:
|
||||||
req = self.decode_abort(req_data)
|
req = self.decode_abort.decode(req_data)
|
||||||
await self._abort_handler(req)
|
await self._abort_handler(req)
|
||||||
|
else:
|
||||||
|
raise Exception(f"Unknown Request Type: {req_type}.")
|
||||||
|
|
||||||
async def _generation_handler(self, req: PDGenerationRequest):
|
async def _generation_handler(self, req: PDGenerationRequest):
|
||||||
"""
|
"""
|
||||||
@ -133,7 +139,5 @@ class PDWorker:
|
|||||||
|
|
||||||
# 4) Serialize and send to PDConroller.
|
# 4) Serialize and send to PDConroller.
|
||||||
response_bytes = self.encoder.encode(response)
|
response_bytes = self.encoder.encode(response)
|
||||||
msg = [PDGenerationResponse.SUCCE, response_bytes]
|
msg = (PDResponseType.GENERATION, response_bytes)
|
||||||
logger.debug("Sending: %s", request_id)
|
await self.to_controller.send_multipart(msg, copy=False)
|
||||||
await self.to_client.send_multipart(msg, copy=False)
|
|
||||||
logger.debug("Sent: %s", request_id)
|
|
||||||
|
|||||||
@ -1,9 +1,10 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
"""
|
"""
|
||||||
Toy connector for prototyping.
|
Toy API Server for prototyping.
|
||||||
|
|
||||||
When PDConroller supports the protocol and we clean up the
|
Once the PDController is more mature and we clean up
|
||||||
OpenAI Server, we can drop this in favor of vllm serve.
|
the OpenAI Server at bit, we can put the PDController
|
||||||
|
directly inside and launch with vllm serve.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from collections.abc import AsyncIterator
|
from collections.abc import AsyncIterator
|
||||||
|
|||||||
@ -10,12 +10,14 @@ from vllm.logger import init_logger
|
|||||||
from vllm.utils import FlexibleArgumentParser
|
from vllm.utils import FlexibleArgumentParser
|
||||||
from vllm.version import __version__ as VLLM_VERSION
|
from vllm.version import __version__ as VLLM_VERSION
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger("vllm.entrypoints.disaggregated.worker")
|
||||||
|
|
||||||
|
|
||||||
async def run(args, engine: EngineClient):
|
async def run(args, engine: EngineClient):
|
||||||
try:
|
try:
|
||||||
worker = PDWorker(engine, args.worker_addr, args.controller_addr)
|
worker = PDWorker(engine=engine,
|
||||||
|
worker_addr=args.worker_addr,
|
||||||
|
controller_addr=args.controller_addr)
|
||||||
await worker.run_busy_loop()
|
await worker.run_busy_loop()
|
||||||
finally:
|
finally:
|
||||||
worker.shutdown()
|
worker.shutdown()
|
||||||
@ -25,7 +27,6 @@ async def main(args) -> None:
|
|||||||
logger.info("vLLM P/D Worker Server %s", VLLM_VERSION)
|
logger.info("vLLM P/D Worker Server %s", VLLM_VERSION)
|
||||||
logger.info("Args: %s", args)
|
logger.info("Args: %s", args)
|
||||||
|
|
||||||
args.disable_frontend_multiprocessing = False
|
|
||||||
async with build_async_engine_client(args) as engine:
|
async with build_async_engine_client(args) as engine:
|
||||||
await run(args, engine)
|
await run(args, engine)
|
||||||
|
|
||||||
@ -40,5 +41,8 @@ if __name__ == "__main__":
|
|||||||
type=str,
|
type=str,
|
||||||
required=True,
|
required=True,
|
||||||
help='The address of the worker.')
|
help='The address of the worker.')
|
||||||
|
parser.add_argument('--disable-frontend-multiprocessing',
|
||||||
|
action="store_true",
|
||||||
|
help='Disable MQLLMEngine for AsyncLLMEngine.')
|
||||||
AsyncEngineArgs.add_cli_args(parser)
|
AsyncEngineArgs.add_cli_args(parser)
|
||||||
uvloop.run(main(parser.parse_args()))
|
uvloop.run(main(parser.parse_args()))
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user