mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-23 16:57:54 +08:00
parent
f51f182d64
commit
cf64b0e6a7
0
vllm/disaggregated/__init__.py
Normal file
0
vllm/disaggregated/__init__.py
Normal file
129
vllm/disaggregated/pd_worker.py
Normal file
129
vllm/disaggregated/pd_worker.py
Normal file
@ -0,0 +1,129 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
|
||||||
|
import msgspec
|
||||||
|
import zmq
|
||||||
|
import zmq.asyncio
|
||||||
|
|
||||||
|
from vllm.disaggregated.protocol import (PDAbortRequest, PDGenerationRequest,
|
||||||
|
PDGenerationResponse, PDRequestType)
|
||||||
|
from vllm.engine.protocol import EngineClient
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class PDWorker:
    """Bridge between a PDClient (over ZMQ) and a local engine."""

    def __init__(
        self,
        engine: EngineClient,
        worker_addr: str,
        client_addr: str,
    ):
        """
        PDWorker
        * Wrapper around AsyncLLM to handle converting PDRequests
          to PDResponse and sending back to the PDClient.
        * Leverages ZMQ for communication with PDClient. We may
          expand this in the future.

        Args:
            engine: engine used to run generation requests.
            worker_addr: IPC path (without the ``ipc://`` prefix) the
                worker pulls requests from.
            client_addr: IPC path (without the ``ipc://`` prefix) the
                worker pushes responses to.
        """
        # Engine.
        self.engine = engine

        # ZMQ IPC.
        self.worker_addr = worker_addr
        self.client_addr = client_addr
        self.ctx = zmq.asyncio.Context()
        self.from_client = self.ctx.socket(zmq.constants.PULL)
        self.from_client.connect(f"ipc://{self.worker_addr}")
        self.to_client = self.ctx.socket(zmq.constants.PUSH)
        self.to_client.connect(f"ipc://{self.client_addr}")
        self.decode_generation = msgspec.msgpack.Decoder(PDGenerationRequest)
        self.decode_abort = msgspec.msgpack.Decoder(PDAbortRequest)
        self.encoder = msgspec.msgpack.Encoder()

        # Active Requests.
        self.running_requests: set[asyncio.Task] = set()

    def shutdown(self):
        """Cancel in-flight work, tear down ZMQ and remove IPC socket files.

        Every step is guarded with ``hasattr`` so this is safe to call on a
        partially constructed instance.
        """
        # Stop the generation tasks first so nothing tries to send on a
        # socket that is about to be closed.
        if hasattr(self, "running_requests"):
            for running_request in self.running_requests:
                running_request.cancel()

        # linger=0 drops unsent messages instead of blocking shutdown
        # while waiting for a peer that may already be gone.
        if hasattr(self, "ctx"):
            self.ctx.destroy(linger=0)

        # client_addr is assigned after worker_addr in __init__, so its
        # presence implies both attributes exist.
        if hasattr(self, "client_addr"):
            ipc_paths = [self.worker_addr, self.client_addr]
            for ipc_path in ipc_paths:
                socket_path = ipc_path.replace("ipc://", "")
                if os.path.exists(socket_path):
                    os.remove(socket_path)

    async def run_busy_loop(self):
        """
        Main execution loop for the PDWorker:
        * 1) Wait for a request from the PDClient.
        * 2) Handle the request.
        """
        logger.info("PDWorker is ready To handle requests.")

        while True:
            # 1) Get request from the PDClient. recv_multipart() is called
            # with its default copy=True, so both frames are plain bytes.
            req_type, req_data = await self.from_client.recv_multipart()

            # 2) Handle the request.
            await self._handle_request(req_type, req_data)

    async def _handle_request(self, req_type: bytes, req_data: bytes):
        """Parse the request type and call the appropriate handler.

        Args:
            req_type: first frame of the message, compared against the
                PDRequestType constants.
            req_data: msgpack-encoded request payload.
        """
        if req_type == PDRequestType.GENERATION:
            req = self.decode_generation.decode(req_data)
            await self._generation_handler(req)
        elif req_type == PDRequestType.ABORT:
            req = self.decode_abort.decode(req_data)
            await self._abort_handler(req)
        else:
            # Surface protocol mismatches instead of dropping them silently,
            # without taking the whole busy loop down.
            logger.warning("Ignoring unknown PD request type: %r", req_type)

    async def _generation_handler(self, req: PDGenerationRequest):
        """Launch generation in a background task."""
        task = asyncio.create_task(self._generate(req))
        # Hold a reference so the task is not garbage collected mid-flight;
        # it removes itself from the set once it completes.
        self.running_requests.add(task)
        task.add_done_callback(self.running_requests.discard)

    async def _abort_handler(self, req: PDAbortRequest):
        """Abort the request in the engine (not implemented yet)."""
        # TODO: Convert running_requests set() into a dict(), keyed
        # by request_id. Cancel the task when an abort comes in.
        # Then update the _generate coroutine to handle a
        # cancel error by aborting in the Engine.
        pass

    async def _generate(self, req: PDGenerationRequest):
        """
        Handle a single PDRequest:
        * 1) submit request to AsyncLLM
        * 2) iterate the RequestOutputs
        * 3) convert RequestOutput --> PDResponse
        * 4) serialize and send to PDClient.
        """
        request_id = req.request_id

        # 1) Submit request to Engine.
        generator = self.engine.generate(
            prompt={"prompt_token_ids": req.prompt_token_ids},
            sampling_params=req.sampling_params,
            request_id=request_id)

        # 2) Iterate RequestOutputs.
        async for request_output in generator:
            # 3) Convert RequestOutput --> PDResponse.
            response = PDGenerationResponse.from_request_output(request_output)

            # 4) Serialize and send to PDClient.
            response_bytes = self.encoder.encode(response)
            logger.debug("Sending: %s", request_id)
            await self.to_client.send(response_bytes, copy=False)
            logger.debug("Sent: %s", request_id)
|
||||||
@ -1,37 +1,40 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
"""Toy Server For Prototyping."""
|
||||||
import uvicorn
|
|
||||||
import uvloop
|
|
||||||
|
|
||||||
from collections.abc import AsyncIterator
|
from collections.abc import AsyncIterator
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
|
import uvicorn
|
||||||
|
import uvloop
|
||||||
from fastapi import FastAPI, Request
|
from fastapi import FastAPI, Request
|
||||||
from fastapi.responses import JSONResponse, StreamingResponse
|
from fastapi.responses import JSONResponse, StreamingResponse
|
||||||
|
|
||||||
from vllm.entrypoints.disaggregated.engine import PDEngine
|
from vllm.entrypoints.disaggregated.engine import PDEngine
|
||||||
|
from vllm.entrypoints.openai.protocol import (CompletionRequest,
|
||||||
|
CompletionResponse,
|
||||||
|
ErrorResponse)
|
||||||
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
|
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
|
||||||
from vllm.entrypoints.openai.serving_models import (BaseModelPath,
|
from vllm.entrypoints.openai.serving_models import (BaseModelPath,
|
||||||
OpenAIServingModels)
|
OpenAIServingModels)
|
||||||
from vllm.entrypoints.openai.protocol import CompletionRequest
|
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.utils import FlexibleArgumentParser, set_ulimit, make_zmq_socket
|
from vllm.utils import FlexibleArgumentParser, set_ulimit
|
||||||
from vllm.entrypoints.openai.protocol import (
|
|
||||||
CompletionResponse, ErrorResponse)
|
|
||||||
|
|
||||||
# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
|
# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
|
||||||
logger = init_logger('vllm.entrypoints.disaggregated.api_server')
|
logger = init_logger('vllm.entrypoints.disaggregated.api_server')
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
|
||||||
|
|
||||||
@app.get("/v1/models")
async def show_available_models(raw_request: Request):
    """Return the models known to the serving layer as a JSON response."""
    serving_models: OpenAIServingModels = (
        raw_request.app.state.openai_serving_models)
    available = await serving_models.show_available_models()
    payload = available.model_dump()
    return JSONResponse(content=payload)
|
||||||
|
|
||||||
|
|
||||||
@app.post("/v1/completions")
|
@app.post("/v1/completions")
|
||||||
async def create_completion(request: CompletionRequest, raw_request: Request):
|
async def create_completion(request: CompletionRequest, raw_request: Request):
|
||||||
handler: OpenAIServingCompletion = raw_request.app.state.openai_serving_completion
|
handler: OpenAIServingCompletion = raw_request.app.state.openai_serving_completion # noqa: E501
|
||||||
generator = await handler.create_completion(request, raw_request)
|
generator = await handler.create_completion(request, raw_request)
|
||||||
if isinstance(generator, ErrorResponse):
|
if isinstance(generator, ErrorResponse):
|
||||||
return JSONResponse(content=generator.model_dump(),
|
return JSONResponse(content=generator.model_dump(),
|
||||||
@ -41,17 +44,16 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
|
|||||||
|
|
||||||
return StreamingResponse(content=generator, media_type="text/event-stream")
|
return StreamingResponse(content=generator, media_type="text/event-stream")
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
async def pd_engine_client_ctx_manager(
        prefill_addr: str, decode_addr: str, connector_addr: str,
        model_name: str) -> AsyncIterator[PDEngine]:
    """Create a PDEngine for the duration of an ``async with`` block.

    Args:
        prefill_addr: address passed to PDEngine for the prefill worker.
        decode_addr: address passed to PDEngine for the decode worker.
        connector_addr: address passed to PDEngine for worker responses.
        model_name: model name passed to PDEngine.

    Yields:
        The constructed PDEngine.
    """
    engine = PDEngine(prefill_addr, decode_addr, connector_addr, model_name)
    try:
        yield engine
    finally:
        # Run shutdown even when the body of the `async with` raises or is
        # cancelled; otherwise the engine's cleanup would be skipped.
        engine.shutdown()
|
||||||
|
|
||||||
|
|
||||||
async def main(args, **uvicorn_kwargs):
|
async def main(args, **uvicorn_kwargs):
|
||||||
logger.info("vLLM Disaggregate Connector Start %s %s", args,
|
logger.info("vLLM Disaggregate Connector Start %s %s", args,
|
||||||
uvicorn_kwargs)
|
uvicorn_kwargs)
|
||||||
@ -71,17 +73,16 @@ async def main(args, **uvicorn_kwargs):
|
|||||||
prefill_addr=prefill_addr,
|
prefill_addr=prefill_addr,
|
||||||
decode_addr=decode_addr,
|
decode_addr=decode_addr,
|
||||||
connector_addr=connector_addr,
|
connector_addr=connector_addr,
|
||||||
model_name=args.model
|
model_name=args.model) as engine_client:
|
||||||
) as engine_client:
|
|
||||||
|
|
||||||
# Initialize App State.
|
# Initialize App State.
|
||||||
model_config = await engine_client.get_model_config()
|
model_config = await engine_client.get_model_config()
|
||||||
app.state.openai_serving_models = OpenAIServingModels(
|
app.state.openai_serving_models = OpenAIServingModels(
|
||||||
engine_client=engine_client,
|
engine_client=engine_client,
|
||||||
model_config=model_config,
|
model_config=model_config,
|
||||||
base_model_paths=[BaseModelPath(
|
base_model_paths=[
|
||||||
name=args.served_model_name or args.model,
|
BaseModelPath(name=args.served_model_name or args.model,
|
||||||
model_path=args.model)
|
model_path=args.model)
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
app.state.openai_serving_completion = OpenAIServingCompletion(
|
app.state.openai_serving_completion = OpenAIServingCompletion(
|
||||||
@ -96,6 +97,7 @@ async def main(args, **uvicorn_kwargs):
|
|||||||
server = uvicorn.Server(config)
|
server = uvicorn.Server(config)
|
||||||
await server.serve()
|
await server.serve()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = FlexibleArgumentParser(
|
parser = FlexibleArgumentParser(
|
||||||
description="vLLM OpenAI-Compatible P/D Server.")
|
description="vLLM OpenAI-Compatible P/D Server.")
|
||||||
|
|||||||
@ -1,324 +0,0 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import os
|
|
||||||
from collections.abc import AsyncGenerator, Mapping
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import msgspec
|
|
||||||
import zmq
|
|
||||||
import zmq.asyncio
|
|
||||||
|
|
||||||
from vllm.config import DecodingConfig, ModelConfig
|
|
||||||
from vllm.core.scheduler import SchedulerOutputs
|
|
||||||
from vllm.entrypoints.disaggregated.types import PDRequest, PDResponse
|
|
||||||
from vllm.inputs.data import PromptType
|
|
||||||
from vllm.inputs.preprocess import InputPreprocessor
|
|
||||||
from vllm.logger import init_logger
|
|
||||||
from vllm.lora.request import LoRARequest
|
|
||||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
|
||||||
from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput
|
|
||||||
from vllm.pooling_params import PoolingParams
|
|
||||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
|
||||||
from vllm.sampling_params import BeamSearchParams, SamplingParams
|
|
||||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
|
||||||
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
|
|
||||||
from vllm.utils import Device
|
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
|
||||||
|
|
||||||
DEFAULT_MAX_TOKENS = 32000
|
|
||||||
|
|
||||||
|
|
||||||
class PDEngine:
    """
    PDEngine:
        Equivalent of AsyncLLM for P/D. Assumes there is
        a Prefill and Decode service already running.

        * TODO: actually handle errors and failure.
        * TODO: support more than just text input.
        * TODO: move under vllm/v1/engine once past prototype.
    """

    def __init__(self, prefill_addr: str, decode_addr: str,
                 connector_addr: str, model_name: str):
        """Bind the prefill/decode sockets and build the dummy config.

        Args:
            prefill_addr: ZMQ address the prefill PUSH socket binds to.
            decode_addr: ZMQ address the decode PUSH socket binds to.
            connector_addr: ZMQ address the output handler binds its PULL
                socket to when it starts.
            model_name: model (and tokenizer) name for the ModelConfig.
        """
        # Request queues, keyed by request_id.
        self.queues: dict[str, asyncio.Queue] = {}

        # Serialization encoder.
        self.encoder = msgspec.msgpack.Encoder()

        # ZMQ communication.
        self.ctx = zmq.asyncio.Context()
        self.to_decode = self.ctx.socket(zmq.constants.PUSH)
        self.to_decode.bind(f"{decode_addr}")
        self.to_prefill = self.ctx.socket(zmq.constants.PUSH)
        self.to_prefill.bind(f"{prefill_addr}")
        self.connector_addr = connector_addr
        self.decode_addr = decode_addr
        self.prefill_addr = prefill_addr

        # Background loops (started on first generate()).
        self.output_handler: Optional[asyncio.Task] = None
        self.log_running: Optional[asyncio.Task] = None

        # Dummy: needed for EngineClient Protocol.
        # TODO: refactor EngineClient to avoid needing this.
        self.model_config = ModelConfig(model=model_name,
                                        tokenizer=model_name,
                                        tokenizer_mode="auto",
                                        trust_remote_code=False,
                                        dtype="auto",
                                        task="generate",
                                        seed=42)

        # Dummy: needed for EngineClient Protocol.
        # TODO: refactor EngineClient to avoid needing this.
        init_kwargs = dict(
            tokenizer_id=self.model_config.tokenizer,
            enable_lora=False,
            max_num_seqs=1024,
            max_loras=0,
            max_input_length=None,
            tokenizer_mode=self.model_config.tokenizer_mode,
            trust_remote_code=self.model_config.trust_remote_code,
            revision=self.model_config.tokenizer_revision,
            truncation_side=self.model_config.truncation_side)
        self.tokenizer = TokenizerGroup(**init_kwargs)

    def shutdown(self):
        """Destroy the ZMQ context, cancel the loops and remove IPC files."""
        if (ctx := self.ctx) is not None:
            ctx.destroy(linger=0)
        if (task := self.log_running) is not None:
            task.cancel()
        if (task := self.output_handler) is not None:
            task.cancel()

        ipc_paths = [self.connector_addr, self.decode_addr, self.prefill_addr]
        for path in ipc_paths:
            socket_path = path.replace("ipc://", "")
            if os.path.exists(socket_path):
                os.remove(socket_path)

    async def _run_log_running(self):
        """Log the number of in-flight requests every 10 seconds."""
        # Loop until cancelled by shutdown(); a single log line at startup
        # would say nothing about the steady state.
        while True:
            logger.info("Running requests: %d", len(self.queues))
            await asyncio.sleep(10.)

    async def _run_output_handler(self):
        """
        Pull responses from Decode + Prefill engines and
        distribute back to the generate() tasks.
        """
        decoder = msgspec.msgpack.Decoder(PDResponse)

        socket: Optional[zmq.asyncio.Socket] = None
        try:
            socket = self.ctx.socket(zmq.constants.PULL)
            socket.bind(self.connector_addr)

            while True:
                response_bytes = await socket.recv()
                response = decoder.decode(response_bytes)
                logger.debug("Got Response: %s", response.request_id)
                # generate() removes its queue once it finishes, so a late
                # response may find no consumer; drop it instead of
                # crashing the handler for every other request.
                q = self.queues.get(response.request_id)
                if q is None:
                    logger.debug("Dropping response for unknown request: %s",
                                 response.request_id)
                    continue
                q.put_nowait(response)
        # TODO: actually handle failure and shutdown. Errors propagate
        # to the task as-is; only the socket cleanup happens here.
        finally:
            if socket is not None:
                socket.close(linger=0)

    async def _prefill(
        self,
        request: PDRequest,
        q: asyncio.Queue[PDResponse],
    ):
        """Send the request to the prefill instance and wait for its reply.

        Raises:
            Exception: if the prefill response reports failure.
        """
        # Send request to the prefill instance.
        req_bytes = self.encoder.encode(request)
        await self.to_prefill.send(req_bytes, copy=False)

        # Wait for the prefill to be done.
        response = await q.get()
        assert response.request_id == request.request_id
        if not response.success:
            # TODO: actual error handling and shutdown.
            raise Exception("Failed Prefill Request.")

    async def _decode(
        self,
        request: PDRequest,
        q: asyncio.Queue[PDResponse],
    ) -> AsyncGenerator[PDResponse, None]:
        """Send the request to the decode instance and yield its responses.

        Stops after yielding the first response with a finish_reason.

        Raises:
            Exception: if a decode response reports failure.
        """
        # Send request to the decode instance.
        req_bytes = self.encoder.encode(request)
        await self.to_decode.send(req_bytes, copy=False)

        # Iterate response queue and yield each response to caller.
        finished = False
        while not finished:
            response = await q.get()
            if not response.success:
                # TODO: actual error handling and shutdown.
                raise Exception("Failed Decode Request.")
            finished = response.finish_reason is not None
            yield response

    def _to_request_output(
        self,
        pd_response: PDResponse,
        prompt_token_ids: list[int],
    ) -> RequestOutput:
        """Wrap a PDResponse into a single-completion RequestOutput."""
        finished = pd_response.finish_reason is not None
        return RequestOutput(
            request_id=pd_response.request_id,
            prompt=None,
            prompt_token_ids=prompt_token_ids,
            prompt_logprobs=None,
            outputs=[
                CompletionOutput(index=0,
                                 text=pd_response.text,
                                 token_ids=pd_response.token_ids,
                                 cumulative_logprob=None,
                                 logprobs=None,
                                 finish_reason=pd_response.finish_reason,
                                 stop_reason=pd_response.stop_reason)
            ],
            finished=finished,
        )

    async def generate(
        self,
        prompt: PromptType,
        sampling_params: SamplingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> AsyncGenerator[RequestOutput, None]:
        """Run prefill then decode for one request, yielding RequestOutputs.

        Raises:
            NotImplementedError: for anything other than a token-id prompt
                without LoRA, tracing, prompt adapter or priority.
            ValueError: if request_id is already in flight.
        """
        # Start loops on first request.
        if self.output_handler is None:
            self.output_handler = asyncio.create_task(
                self._run_output_handler())
            self.log_running = asyncio.create_task(self._run_log_running())

        # TODO: Expand to support the full matrix.
        if "prompt_token_ids" not in prompt:
            raise NotImplementedError(
                "We currently only support TokensPrompt for P/D!")
        if lora_request is not None:
            raise NotImplementedError(
                "We currently do not support LoRA for P/D!")
        if trace_headers is not None:
            raise NotImplementedError(
                "We currently do not support tracing for P/D!")
        if prompt_adapter_request is not None:
            raise NotImplementedError(
                "We currently do not support prompt adapter for P/D!")
        if priority != 0:
            raise NotImplementedError(
                "We currently do not support priority for P/D!")
        if request_id in self.queues:
            raise ValueError(f"Found duplicate request_id: {request_id}!")

        # Queue to gather output from output_handler.
        q: asyncio.Queue[PDResponse] = asyncio.Queue()
        self.queues[request_id] = q

        try:
            # (1) Perform the Prefill.
            original_max_tokens = sampling_params.max_tokens
            prompt_token_ids = prompt["prompt_token_ids"]
            request = PDRequest(request_id=request_id,
                                prompt_token_ids=prompt_token_ids,
                                sampling_params=sampling_params)
            # The prefill instance only needs to produce a single token.
            # This mutates the caller's SamplingParams, so always restore
            # it, even when the prefill fails.
            request.sampling_params.max_tokens = 1
            logger.debug("Sending Prefill: %s", request.request_id)
            try:
                await self._prefill(request, q)
            finally:
                request.sampling_params.max_tokens = original_max_tokens

            # (2) Perform the Decodes.
            logger.debug("Sending Decode: %s", request.request_id)
            async for pd_response in self._decode(request, q):
                logger.debug("Got Decode: %s", request.request_id)
                yield self._to_request_output(pd_response, prompt_token_ids)
        finally:
            # Drop the per-request queue so finished, failed or abandoned
            # requests do not accumulate and the request_id can be reused.
            self.queues.pop(request_id, None)

    async def beam_search(
        self,
        prompt: PromptType,
        request_id: str,
        params: BeamSearchParams,
    ) -> AsyncGenerator[RequestOutput, None]:
        """Not supported by the PDEngine."""
        raise NotImplementedError

    def encode(
        self,
        prompt: PromptType,
        pooling_params: PoolingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        priority: int = 0,
    ) -> AsyncGenerator[PoolingRequestOutput, None]:
        """Not supported by the PDEngine."""
        raise NotImplementedError

    async def abort(self, request_id: str) -> None:
        """Not supported by the PDEngine."""
        raise NotImplementedError

    async def get_model_config(self) -> ModelConfig:
        """Return the dummy ModelConfig built in __init__."""
        return self.model_config

    async def get_decoding_config(self) -> DecodingConfig:
        """Not supported by the PDEngine."""
        raise NotImplementedError

    async def get_input_preprocessor(self) -> InputPreprocessor:
        """Not supported by the PDEngine."""
        raise NotImplementedError

    async def get_tokenizer(
        self,
        lora_request: Optional[LoRARequest] = None,
    ) -> AnyTokenizer:
        """Return the base tokenizer; LoRA tokenizers are not supported."""
        if lora_request is not None:
            raise NotImplementedError(
                "LoRA is not yet supported in the PDEngine.")
        return self.tokenizer.get_lora_tokenizer(None)

    async def is_tracing_enabled(self) -> bool:
        """Tracing is never enabled for the PDEngine."""
        return False

    async def do_log_stats(
        self,
        scheduler_outputs: Optional[SchedulerOutputs] = None,
        model_output: Optional[list[SamplerOutput]] = None,
    ) -> None:
        """No-op."""
        pass

    async def check_health(self) -> None:
        """No-op."""
        pass

    async def start_profile(self) -> None:
        """Not supported by the PDEngine."""
        raise NotImplementedError

    async def stop_profile(self) -> None:
        """Not supported by the PDEngine."""
        raise NotImplementedError

    async def reset_prefix_cache(self,
                                 device: Optional[Device] = None) -> None:
        """Not supported by the PDEngine."""
        raise NotImplementedError

    async def sleep(self, level: int = 1) -> None:
        """Not supported by the PDEngine."""
        raise NotImplementedError

    async def wake_up(self) -> None:
        """Not supported by the PDEngine."""
        raise NotImplementedError

    async def is_sleeping(self) -> bool:
        """The PDEngine never sleeps."""
        return False

    async def add_lora(self, lora_request: LoRARequest) -> None:
        """Not supported by the PDEngine."""
        raise NotImplementedError

    @property
    def errored(self) -> bool:
        """Always False; error tracking is not implemented."""
        return False
|
|
||||||
@ -1,38 +0,0 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
|
||||||
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import msgspec
|
|
||||||
|
|
||||||
from vllm import SamplingParams
|
|
||||||
|
|
||||||
# NOTE FOR DEVELOPERS:
|
|
||||||
# DO NOT USE PICKLE FOR THESE CLASSES. IN A MULTI NODE
|
|
||||||
# SETUP WE WILL USE TCP. WE CANNOT USE PICKLE OTHERWISE
|
|
||||||
# WE RISK REMOTE CODE EXECUTION FROM UNTRUSTED USERS.
|
|
||||||
|
|
||||||
|
|
||||||
class PDRequest(
        msgspec.Struct,
        array_like=True,  # type: ignore[call-arg]
        omit_defaults=True,  # type: ignore[call-arg]
        gc=False):  # type: ignore[call-arg]
    # Request message carrying a token-id prompt and its sampling
    # parameters, identified by request_id.
    request_id: str
    prompt_token_ids: list[int]
    sampling_params: SamplingParams
    # TODO: support multimodal inputs.
|
|
||||||
|
|
||||||
|
|
||||||
class PDResponse(
        msgspec.Struct,
        array_like=True,  # type: ignore[call-arg]
        omit_defaults=True,  # type: ignore[call-arg]
        gc=False):  # type: ignore[call-arg]
    # Response message for one request_id. `success`, `text` and
    # `token_ids` have no defaults, so they must always be supplied.
    request_id: str
    success: bool
    text: str
    token_ids: list[int]
    finish_reason: Optional[str] = None
    stop_reason: Optional[str] = None
    # TODO: support full protocol.
    # NOTE(review): no annotation here, so this is presumably a plain class
    # attribute rather than a serialized field - confirm.
    logprobs = None
|
|
||||||
@ -1,136 +0,0 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import msgspec
|
|
||||||
import signal
|
|
||||||
import uvloop
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import zmq
|
|
||||||
import zmq.asyncio
|
|
||||||
|
|
||||||
from vllm.inputs.data import TokensPrompt
|
|
||||||
from vllm.engine.async_llm_engine import AsyncEngineArgs
|
|
||||||
from vllm.engine.protocol import EngineClient
|
|
||||||
from vllm.entrypoints.disaggregated.types import PDRequest, PDResponse
|
|
||||||
from vllm.entrypoints.openai.api_server import build_async_engine_client
|
|
||||||
from vllm.logger import init_logger
|
|
||||||
from vllm.utils import FlexibleArgumentParser, set_ulimit
|
|
||||||
from vllm.version import __version__ as VLLM_VERSION
|
|
||||||
|
|
||||||
# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
|
|
||||||
logger = init_logger('vllm.entrypoints.disaggregated.worker')
|
|
||||||
|
|
||||||
|
|
||||||
async def handle_request(
    request: PDRequest,
    engine: EngineClient,
    socket: zmq.asyncio.Socket,
    encoder: msgspec.msgpack.Encoder,
) -> None:
    """Run one request through the engine and stream results back.

    Each RequestOutput from the engine is converted to a PDResponse,
    encoded with ``encoder`` and sent on ``socket``. If anything fails, a
    single ``success=False`` PDResponse is sent instead.

    Args:
        request: the request to run.
        engine: engine used to generate outputs.
        socket: socket the encoded responses are sent on.
        encoder: msgpack encoder for the responses.
    """
    request_id = request.request_id
    try:
        # 1) Generate RequestOutputs.
        prompt: TokensPrompt = {
            "prompt_token_ids": request.prompt_token_ids
        }
        async for request_output in engine.generate(
                prompt=prompt,
                sampling_params=request.sampling_params,
                request_id=request_id):

            assert len(request_output.outputs
                       ) == 1, "Only support N=1 right now."
            out = request_output.outputs[0]

            # 2) Convert RequestOutput --> PDResponse.
            response = PDResponse(
                request_id=request_id,
                success=True,
                text=out.text,
                token_ids=out.token_ids,
                finish_reason=out.finish_reason,
                stop_reason=out.stop_reason,
            )
            response_bytes = encoder.encode(response)

            # 3) Send to Connector.
            logger.info("Sending: %s", request_id)
            await socket.send(response_bytes, copy=False)
            logger.info("Sent: %s", request_id)

    except Exception as e:
        # TODO: actual error handling.
        logger.exception("Exception in Worker Routine: %s request_id: %s", e,
                         request_id)
        # `text` and `token_ids` are required fields of PDResponse, so the
        # failure notice must supply them too.
        response = PDResponse(request_id=request_id,
                              success=False,
                              text="",
                              token_ids=[])
        response_bytes = encoder.encode(response)
        # Send the encoded bytes, not the Struct object itself.
        await socket.send(response_bytes, copy=False)
|
|
||||||
|
|
||||||
async def run_server(args, engine: EngineClient):
    """Get Requests and Handle Them.

    Pulls encoded PDRequests from the connector, and spawns one
    handle_request task per request until the loop is interrupted.
    """
    logger.info("P/D Worker is Ready To Recieve Requests.")

    in_flight: set[asyncio.Task] = set()
    request_decoder = msgspec.msgpack.Decoder(PDRequest)
    response_encoder = msgspec.msgpack.Encoder()

    zmq_ctx: Optional[zmq.asyncio.Context] = None
    try:
        # IPC Setup: one inbound PULL socket, one outbound PUSH socket.
        zmq_ctx = zmq.asyncio.Context()
        inbound = zmq_ctx.socket(zmq.constants.PULL)
        inbound.connect(f"ipc://{args.worker_addr}")
        outbound = zmq_ctx.socket(zmq.constants.PUSH)
        outbound.connect(f"ipc://{args.connector_addr}")

        # Main Loop.
        while True:
            # 1) Get request from the Connector.
            raw_request = await inbound.recv()
            pd_request = request_decoder.decode(raw_request)

            # 2) Launch a coroutine to handle the request. The set keeps
            # a reference until the task finishes and discards itself.
            handler = handle_request(pd_request, engine, outbound,
                                     response_encoder)
            task = asyncio.create_task(handler)
            in_flight.add(task)
            task.add_done_callback(in_flight.discard)

    except KeyboardInterrupt:
        logger.debug("Worker server loop interrupted.")

    finally:
        for pending in in_flight:
            pending.cancel()
        if zmq_ctx is not None:
            zmq_ctx.destroy(linger=0)
|
|
||||||
|
|
||||||
|
|
||||||
async def main(args) -> None:
    """Build the engine client from ``args`` and run the worker loop."""
    logger.info("vLLM P/D Worker Server %s", VLLM_VERSION)
    logger.info("args: %s", args)

    # Workaround to avoid footguns where uvicorn drops requests
    # with too many concurrent requests active due to ulimit.
    set_ulimit()

    def _raise_on_sigterm(*_) -> None:
        # Turn SIGTERM into KeyboardInterrupt so initialization can be
        # interrupted.
        raise KeyboardInterrupt("terminated")

    signal.signal(signal.SIGTERM, _raise_on_sigterm)

    args.disable_frontend_multiprocessing = False
    async with build_async_engine_client(args) as engine_client:
        await run_server(args, engine_client)
|
|
||||||
|
|
||||||
if __name__ == "__main__":
    # CLI entry point: the two worker-specific addresses plus the
    # standard engine arguments.
    parser = FlexibleArgumentParser()
    parser.add_argument(
        '--connector-addr',
        type=str,
        required=True,
        help='The address of the connector.',
    )
    parser.add_argument(
        '--worker-addr',
        type=str,
        required=True,
        help='The address of the worker.',
    )
    AsyncEngineArgs.add_cli_args(parser)
    args = parser.parse_args()
    uvloop.run(main(args))
|
|
||||||
Loading…
x
Reference in New Issue
Block a user