mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-21 18:07:04 +08:00
added __init__.py
Signed-off-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com>
This commit is contained in:
parent
d5b0db449e
commit
284d5df45b
@ -45,14 +45,14 @@ wait_for_disagg_server() {
|
|||||||
|
|
||||||
# You can also adjust --kv-ip and --kv-port for distributed inference.
|
# You can also adjust --kv-ip and --kv-port for distributed inference.
|
||||||
MODEL=meta-llama/Llama-3.1-8B-Instruct
|
MODEL=meta-llama/Llama-3.1-8B-Instruct
|
||||||
CONNECTOR_ADDR=connectoripc
|
contoller_addr=connectoripc
|
||||||
PREFILL_WORKER_ADDR=prefillipc
|
PREFILL_WORKER_ADDR=prefillipc
|
||||||
DECODE_WORKER_ADDR=decodeipc
|
DECODE_WORKER_ADDR=decodeipc
|
||||||
|
|
||||||
# prefilling instance, which is the KV producer
|
# prefilling instance, which is the KV producer
|
||||||
CUDA_VISIBLE_DEVICES=0 python3 ../vllm/entrypoints/disaggregated/worker.py \
|
CUDA_VISIBLE_DEVICES=0 python3 ../vllm/entrypoints/disaggregated/worker.py \
|
||||||
--model $MODEL \
|
--model $MODEL \
|
||||||
--connector-addr $CONNECTOR_ADDR \
|
--connector-addr $contoller_addr \
|
||||||
--worker-addr $PREFILL_WORKER_ADDR \
|
--worker-addr $PREFILL_WORKER_ADDR \
|
||||||
--max-model-len 100 \
|
--max-model-len 100 \
|
||||||
--gpu-memory-utilization 0.8 \
|
--gpu-memory-utilization 0.8 \
|
||||||
@ -62,7 +62,7 @@ CUDA_VISIBLE_DEVICES=0 python3 ../vllm/entrypoints/disaggregated/worker.py \
|
|||||||
# decoding instance, which is the KV consumer
|
# decoding instance, which is the KV consumer
|
||||||
CUDA_VISIBLE_DEVICES=1 python3 ../vllm/entrypoints/disaggregated/worker.py \
|
CUDA_VISIBLE_DEVICES=1 python3 ../vllm/entrypoints/disaggregated/worker.py \
|
||||||
--model $MODEL \
|
--model $MODEL \
|
||||||
--connector-addr $CONNECTOR_ADDR \
|
--connector-addr $contoller_addr \
|
||||||
--worker-addr $DECODE_WORKER_ADDR \
|
--worker-addr $DECODE_WORKER_ADDR \
|
||||||
--max-model-len 100 \
|
--max-model-len 100 \
|
||||||
--gpu-memory-utilization 0.8 \
|
--gpu-memory-utilization 0.8 \
|
||||||
@ -76,7 +76,7 @@ CUDA_VISIBLE_DEVICES=1 python3 ../vllm/entrypoints/disaggregated/worker.py \
|
|||||||
python3 ../vllm/entrypoints/disaggregated/connector.py \
|
python3 ../vllm/entrypoints/disaggregated/connector.py \
|
||||||
--port $PORT \
|
--port $PORT \
|
||||||
--model $MODEL \
|
--model $MODEL \
|
||||||
--connector-addr $CONNECTOR_ADDR \
|
--connector-addr $contoller_addr \
|
||||||
--prefill-addr $PREFILL_WORKER_ADDR \
|
--prefill-addr $PREFILL_WORKER_ADDR \
|
||||||
--decode-addr $DECODE_WORKER_ADDR
|
--decode-addr $DECODE_WORKER_ADDR
|
||||||
|
|
||||||
|
|||||||
@ -1,359 +0,0 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import os
|
|
||||||
from collections.abc import AsyncGenerator, Mapping
|
|
||||||
from typing import Optional, Union
|
|
||||||
|
|
||||||
import msgspec
|
|
||||||
import zmq
|
|
||||||
import zmq.asyncio
|
|
||||||
|
|
||||||
from vllm.config import DecodingConfig, ModelConfig
|
|
||||||
from vllm.core.scheduler import SchedulerOutputs
|
|
||||||
from vllm.disaggregated.protocol import (PDGenerationRequest,
|
|
||||||
PDGenerationResponse, PDRequestType,
|
|
||||||
PDResponseType)
|
|
||||||
from vllm.engine.protocol import EngineClient
|
|
||||||
from vllm.inputs.data import PromptType
|
|
||||||
from vllm.inputs.preprocess import InputPreprocessor
|
|
||||||
from vllm.logger import init_logger
|
|
||||||
from vllm.lora.request import LoRARequest
|
|
||||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
|
||||||
from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput
|
|
||||||
from vllm.pooling_params import PoolingParams
|
|
||||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
|
||||||
from vllm.sampling_params import BeamSearchParams, SamplingParams
|
|
||||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
|
||||||
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
|
|
||||||
from vllm.utils import Device
|
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
|
||||||
|
|
||||||
DEFAULT_MAX_TOKENS = 32000
|
|
||||||
|
|
||||||
|
|
||||||
class PDController(EngineClient):
|
|
||||||
"""
|
|
||||||
Controller that schedules work on the PDWorkers.
|
|
||||||
|
|
||||||
Conforms for the EngineClient protocol so it can
|
|
||||||
be wrapped with the OpenAI Server.
|
|
||||||
|
|
||||||
Two Phases:
|
|
||||||
* Send request to prefill worker, await ack.
|
|
||||||
* Send request to decode worker.
|
|
||||||
|
|
||||||
KVSync happens directly between Engines,
|
|
||||||
handled by vLLM KVCacheTransfer.
|
|
||||||
|
|
||||||
[ OpenAI Server ]
|
|
||||||
|
|
|
||||||
[ PDController ]
|
|
||||||
|
|
|
||||||
[ zmq ]
|
|
||||||
|
|
|
||||||
[ PDWorker ] [ PDWorker ]
|
|
||||||
|
|
|
||||||
[ Engine ] <---> [ Engine ]
|
|
||||||
|
|
||||||
After PR #12957, we will support xPyD, so we will
|
|
||||||
also need to implement a scheduler and service
|
|
||||||
discovery for the workers.
|
|
||||||
|
|
||||||
This PDController may be implemented as a K8s
|
|
||||||
controller. This is intended to be a prototype.
|
|
||||||
|
|
||||||
* TODO: better error handling
|
|
||||||
* TODO: support logprobs, multimodal, etc.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, prefill_addr: str, decode_addr: str,
|
|
||||||
connector_addr: str, model_name: str):
|
|
||||||
# Request queues.
|
|
||||||
self.queues: dict[str, asyncio.Queue] = {}
|
|
||||||
|
|
||||||
# Serialization encoder.
|
|
||||||
self.encoder = msgspec.msgpack.Encoder()
|
|
||||||
|
|
||||||
# ZMQ communication.
|
|
||||||
# TODO: once https://github.com/vllm-project/vllm/pull/12957
|
|
||||||
# lands, do service discovery to scale out workers.
|
|
||||||
self.ctx = zmq.asyncio.Context()
|
|
||||||
self.to_decode = self.ctx.socket(zmq.constants.PUSH)
|
|
||||||
self.to_decode.bind(f"{decode_addr}")
|
|
||||||
self.to_prefill = self.ctx.socket(zmq.constants.PUSH)
|
|
||||||
self.to_prefill.bind(f"{prefill_addr}")
|
|
||||||
self.connector_addr = connector_addr
|
|
||||||
self.decode_addr = decode_addr
|
|
||||||
self.prefill_addr = prefill_addr
|
|
||||||
|
|
||||||
# Background loops (started on first generate()).
|
|
||||||
self.output_handler: Optional[asyncio.Task] = None
|
|
||||||
self.log_running: Optional[asyncio.Task] = None
|
|
||||||
|
|
||||||
# Dummy: needed for EngineClient Protocol.
|
|
||||||
# TODO: refactor OAI Server to avoid needing this.
|
|
||||||
self.model_config = ModelConfig(model=model_name,
|
|
||||||
tokenizer=model_name,
|
|
||||||
tokenizer_mode="auto",
|
|
||||||
trust_remote_code=False,
|
|
||||||
dtype="auto",
|
|
||||||
task="generate",
|
|
||||||
seed=42)
|
|
||||||
|
|
||||||
# Dummy: needed for EngineClient Protocol.
|
|
||||||
# TODO: refactor OAI Server to avoid needing this.
|
|
||||||
self.tokenizer = TokenizerGroup(
|
|
||||||
**dict(tokenizer_id=self.model_config.tokenizer,
|
|
||||||
enable_lora=False,
|
|
||||||
max_num_seqs=1024,
|
|
||||||
max_loras=0,
|
|
||||||
max_input_length=None,
|
|
||||||
tokenizer_mode=self.model_config.tokenizer_mode,
|
|
||||||
trust_remote_code=self.model_config.trust_remote_code,
|
|
||||||
revision=self.model_config.tokenizer_revision,
|
|
||||||
truncation_side=self.model_config.truncation_side))
|
|
||||||
|
|
||||||
def shutdown(self):
|
|
||||||
if (ctx := self.ctx) is not None:
|
|
||||||
ctx.destroy(linger=0)
|
|
||||||
if (task := self.log_running) is not None:
|
|
||||||
task.cancel()
|
|
||||||
if (task := self.output_handler) is not None:
|
|
||||||
task.cancel()
|
|
||||||
|
|
||||||
ipc_paths = [self.connector_addr, self.decode_addr, self.prefill_addr]
|
|
||||||
for path in ipc_paths:
|
|
||||||
socket_path = path.replace("ipc://", "")
|
|
||||||
if os.path.exists(socket_path):
|
|
||||||
os.remove(socket_path)
|
|
||||||
|
|
||||||
async def _run_log_running(self):
|
|
||||||
logger.info("Running requests: %d", len(self.queues))
|
|
||||||
await asyncio.sleep(10.)
|
|
||||||
|
|
||||||
async def _run_output_handler(self):
|
|
||||||
"""
|
|
||||||
Pull responses from Decode + Prefill engines and
|
|
||||||
distribute back to the generate() tasks.
|
|
||||||
"""
|
|
||||||
decoder = msgspec.msgpack.Decoder(PDGenerationResponse)
|
|
||||||
|
|
||||||
socket: Optional[zmq.asyncio.Socket] = None
|
|
||||||
try:
|
|
||||||
socket = self.ctx.socket(zmq.constants.PULL)
|
|
||||||
socket.bind(self.connector_addr)
|
|
||||||
|
|
||||||
while True:
|
|
||||||
res_type, res_data = await socket.recv_multipart()
|
|
||||||
if res_type == PDResponseType.FAILURE:
|
|
||||||
raise Exception("Failure Response from PDWorker.")
|
|
||||||
elif res_type == PDResponseType.GENERATION:
|
|
||||||
response = decoder.decode(res_data)
|
|
||||||
logger.debug("Got Response: %s", response.request_id)
|
|
||||||
self.queues[response.request_id].put_nowait(response)
|
|
||||||
else:
|
|
||||||
raise Exception("Unknown response type.")
|
|
||||||
except Exception as e:
|
|
||||||
# TODO: distinguish between fatal and non-fatal errors.
|
|
||||||
for q in self.queues.values():
|
|
||||||
q.put_nowait(e)
|
|
||||||
raise e
|
|
||||||
finally:
|
|
||||||
if socket is not None:
|
|
||||||
socket.close(linger=0)
|
|
||||||
|
|
||||||
async def _run_prefill(
|
|
||||||
self,
|
|
||||||
request: PDGenerationRequest,
|
|
||||||
q: asyncio.Queue[Union[Exception, PDGenerationResponse]],
|
|
||||||
):
|
|
||||||
# Send request to the prefill instance.
|
|
||||||
req_bytes = self.encoder.encode(request)
|
|
||||||
msg = (PDRequestType.GENERATION, req_bytes)
|
|
||||||
await self.to_prefill.send_multipart(msg, copy=False)
|
|
||||||
|
|
||||||
# Await completion of the prefill.
|
|
||||||
response = await q.get()
|
|
||||||
if isinstance(response, Exception):
|
|
||||||
raise response
|
|
||||||
logger.debug("Got Decode Response: %s", request.request_id)
|
|
||||||
|
|
||||||
async def _run_decode(
|
|
||||||
self,
|
|
||||||
request: PDGenerationRequest,
|
|
||||||
q: asyncio.Queue[Union[Exception, PDGenerationResponse]],
|
|
||||||
) -> AsyncGenerator[PDGenerationResponse]:
|
|
||||||
# Send request to the decode instance.
|
|
||||||
req_bytes = self.encoder.encode(request)
|
|
||||||
msg = (PDRequestType.GENERATION, req_bytes)
|
|
||||||
await self.to_decode.send_multipart(msg, copy=False)
|
|
||||||
|
|
||||||
# Iterate response queue and yield each response to caller.
|
|
||||||
finished = False
|
|
||||||
while not finished:
|
|
||||||
response = await q.get()
|
|
||||||
if isinstance(response, Exception):
|
|
||||||
raise response
|
|
||||||
logger.debug("Got Decode Response: %s", request.request_id)
|
|
||||||
finished = response.finish_reason is not None
|
|
||||||
yield response
|
|
||||||
|
|
||||||
def _to_request_output(
|
|
||||||
self,
|
|
||||||
response: PDGenerationResponse,
|
|
||||||
prompt_token_ids: list[int],
|
|
||||||
) -> RequestOutput:
|
|
||||||
finished = response.finish_reason is not None
|
|
||||||
return RequestOutput(
|
|
||||||
request_id=response.request_id,
|
|
||||||
prompt=None,
|
|
||||||
prompt_token_ids=prompt_token_ids,
|
|
||||||
prompt_logprobs=None,
|
|
||||||
outputs=[
|
|
||||||
CompletionOutput(index=0,
|
|
||||||
text=response.text,
|
|
||||||
token_ids=response.token_ids,
|
|
||||||
cumulative_logprob=None,
|
|
||||||
logprobs=None,
|
|
||||||
finish_reason=response.finish_reason,
|
|
||||||
stop_reason=response.stop_reason)
|
|
||||||
],
|
|
||||||
finished=finished,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def generate(
|
|
||||||
self,
|
|
||||||
prompt: PromptType,
|
|
||||||
sampling_params: SamplingParams,
|
|
||||||
request_id: str,
|
|
||||||
lora_request: Optional[LoRARequest] = None,
|
|
||||||
trace_headers: Optional[Mapping[str, str]] = None,
|
|
||||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
|
||||||
priority: int = 0,
|
|
||||||
) -> AsyncGenerator[RequestOutput]:
|
|
||||||
# Start loops on first request.
|
|
||||||
if self.output_handler is None:
|
|
||||||
self.output_handler = asyncio.create_task(
|
|
||||||
self._run_output_handler())
|
|
||||||
self.log_running = asyncio.create_task(self._run_log_running())
|
|
||||||
|
|
||||||
# TODO: Expand to support the full matrix.
|
|
||||||
if "prompt_token_ids" not in prompt:
|
|
||||||
raise NotImplementedError(
|
|
||||||
"We currently only support TokensPrompt for P/D!")
|
|
||||||
if lora_request is not None:
|
|
||||||
raise NotImplementedError(
|
|
||||||
"We currently do not support LoRA for P/D!")
|
|
||||||
if trace_headers is not None:
|
|
||||||
raise NotImplementedError(
|
|
||||||
"We currently do not support tracing for P/D!")
|
|
||||||
if prompt_adapter_request is not None:
|
|
||||||
raise NotImplementedError(
|
|
||||||
"We currently do not support prompt adapter for P/D!")
|
|
||||||
if priority != 0:
|
|
||||||
raise NotImplementedError(
|
|
||||||
"We currently do not support priority for P/D!")
|
|
||||||
if request_id in self.queues:
|
|
||||||
raise ValueError(f"Found duplicate request_id: {request_id}!")
|
|
||||||
|
|
||||||
# Queue to gather output from output_handler.
|
|
||||||
q = asyncio.Queue()
|
|
||||||
self.queues[request_id] = q
|
|
||||||
|
|
||||||
# (1) Perform the Prefill.
|
|
||||||
original_max_tokens = sampling_params.max_tokens
|
|
||||||
request = PDGenerationRequest(
|
|
||||||
request_id=request_id,
|
|
||||||
prompt_token_ids=prompt["prompt_token_ids"],
|
|
||||||
sampling_params=sampling_params)
|
|
||||||
request.sampling_params.max_tokens = 1
|
|
||||||
logger.debug("Sending Prefill: %s", request.request_id)
|
|
||||||
pd_response = await self._run_prefill(request, q)
|
|
||||||
|
|
||||||
# (2) Perform the Decodes.
|
|
||||||
logger.debug("Sending Decode: %s", request.request_id)
|
|
||||||
request.sampling_params.max_tokens = original_max_tokens
|
|
||||||
async for pd_response in self._run_decode(request, q):
|
|
||||||
yield self._to_request_output(pd_response,
|
|
||||||
prompt["prompt_token_ids"])
|
|
||||||
|
|
||||||
async def beam_search(
|
|
||||||
self,
|
|
||||||
prompt: PromptType,
|
|
||||||
request_id: str,
|
|
||||||
params: BeamSearchParams,
|
|
||||||
) -> AsyncGenerator[RequestOutput, None]:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def encode(
|
|
||||||
self,
|
|
||||||
prompt: PromptType,
|
|
||||||
pooling_params: PoolingParams,
|
|
||||||
request_id: str,
|
|
||||||
lora_request: Optional[LoRARequest] = None,
|
|
||||||
trace_headers: Optional[Mapping[str, str]] = None,
|
|
||||||
priority: int = 0,
|
|
||||||
) -> AsyncGenerator[PoolingRequestOutput, None]:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
async def abort(self, request_id: str) -> None:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
async def get_model_config(self) -> ModelConfig:
|
|
||||||
return self.model_config
|
|
||||||
|
|
||||||
async def get_decoding_config(self) -> DecodingConfig:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
async def get_input_preprocessor(self) -> InputPreprocessor:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
async def get_tokenizer(
|
|
||||||
self,
|
|
||||||
lora_request: Optional[LoRARequest] = None,
|
|
||||||
) -> AnyTokenizer:
|
|
||||||
if lora_request is not None:
|
|
||||||
raise NotImplementedError(
|
|
||||||
"LoRA is not yet supported in the PDEngine.")
|
|
||||||
return self.tokenizer.get_lora_tokenizer(None)
|
|
||||||
|
|
||||||
async def is_tracing_enabled(self) -> bool:
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def do_log_stats(
|
|
||||||
self,
|
|
||||||
scheduler_outputs: Optional[SchedulerOutputs] = None,
|
|
||||||
model_output: Optional[list[SamplerOutput]] = None,
|
|
||||||
) -> None:
|
|
||||||
pass
|
|
||||||
|
|
||||||
async def check_health(self) -> None:
|
|
||||||
pass
|
|
||||||
|
|
||||||
async def start_profile(self) -> None:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
async def stop_profile(self) -> None:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
async def reset_prefix_cache(self,
|
|
||||||
device: Optional[Device] = None) -> None:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
async def sleep(self, level: int = 1) -> None:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
async def wake_up(self) -> None:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
async def is_sleeping(self) -> bool:
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def add_lora(self, lora_request: LoRARequest) -> None:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
@property
|
|
||||||
def errored(self) -> bool:
|
|
||||||
return False
|
|
||||||
@ -26,8 +26,8 @@ class PDWorker:
|
|||||||
"""
|
"""
|
||||||
PDWorker
|
PDWorker
|
||||||
* Wrapper around AsyncLLM to handle converting PRRequests
|
* Wrapper around AsyncLLM to handle converting PRRequests
|
||||||
to PDResponse and sending back to the PDClient.
|
to PDResponse and sending back to the PDConroller.
|
||||||
* Leverages ZMQ for communication with PDClient. We may
|
* Leverages ZMQ for communication with PDConroller. We may
|
||||||
expand this in the future.
|
expand this in the future.
|
||||||
"""
|
"""
|
||||||
# Engine.
|
# Engine.
|
||||||
@ -66,7 +66,7 @@ class PDWorker:
|
|||||||
async def run_busy_loop(self):
|
async def run_busy_loop(self):
|
||||||
"""
|
"""
|
||||||
main loop:
|
main loop:
|
||||||
1) wait for a request from the PDClient
|
1) wait for a request from the PDConroller
|
||||||
2) handle the request
|
2) handle the request
|
||||||
"""
|
"""
|
||||||
logger.info("PDWorker is ready To handle requests.")
|
logger.info("PDWorker is ready To handle requests.")
|
||||||
@ -116,7 +116,7 @@ class PDWorker:
|
|||||||
* 1) submit request to AsyncLLM
|
* 1) submit request to AsyncLLM
|
||||||
* 2) iterate the RequestOutputs
|
* 2) iterate the RequestOutputs
|
||||||
* 3) convert RequestOutput --> PDResponse
|
* 3) convert RequestOutput --> PDResponse
|
||||||
* 4) serialize and send to PDClient
|
* 4) serialize and send to PDConroller
|
||||||
"""
|
"""
|
||||||
request_id = req.request_id
|
request_id = req.request_id
|
||||||
|
|
||||||
@ -131,8 +131,9 @@ class PDWorker:
|
|||||||
# 3) Convert RequestOutput --> PDResponse.
|
# 3) Convert RequestOutput --> PDResponse.
|
||||||
response = PDGenerationResponse.from_request_output(request_output)
|
response = PDGenerationResponse.from_request_output(request_output)
|
||||||
|
|
||||||
# 4) Serialize and send to PDClient.
|
# 4) Serialize and send to PDConroller.
|
||||||
response_bytes = self.encoder.encode(response)
|
response_bytes = self.encoder.encode(response)
|
||||||
|
msg = [PDGenerationResponse.SUCCE, response_bytes]
|
||||||
logger.debug("Sending: %s", request_id)
|
logger.debug("Sending: %s", request_id)
|
||||||
await self.to_client.send(response_bytes, copy=False)
|
await self.to_client.send_multipart(msg, copy=False)
|
||||||
logger.debug("Sent: %s", request_id)
|
logger.debug("Sent: %s", request_id)
|
||||||
|
|||||||
@ -14,8 +14,8 @@ from vllm.outputs import RequestOutput
|
|||||||
|
|
||||||
|
|
||||||
class PDRequestType:
|
class PDRequestType:
|
||||||
GENERATION = b"generation"
|
GENERATION = b'\x00'
|
||||||
ABORT = b"abort"
|
ABORT = b'\x01'
|
||||||
|
|
||||||
|
|
||||||
class PDGenerationRequest(msgspec.Struct):
|
class PDGenerationRequest(msgspec.Struct):
|
||||||
@ -30,8 +30,8 @@ class PDAbortRequest(msgspec.Struct):
|
|||||||
|
|
||||||
|
|
||||||
class PDResponseType:
|
class PDResponseType:
|
||||||
GENERATION = b"generation"
|
GENERATION = b'\x00'
|
||||||
FAILURE = b"failure"
|
FAILURE = b'\x01'
|
||||||
|
|
||||||
|
|
||||||
class PDGenerationResponse(msgspec.Struct):
|
class PDGenerationResponse(msgspec.Struct):
|
||||||
|
|||||||
@ -2,7 +2,7 @@
|
|||||||
"""
|
"""
|
||||||
Toy connector for prototyping.
|
Toy connector for prototyping.
|
||||||
|
|
||||||
When PDClient supports the protocol and we clean up the
|
When PDConroller supports the protocol and we clean up the
|
||||||
OpenAI Server, we can drop this in favor of vllm serve.
|
OpenAI Server, we can drop this in favor of vllm serve.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@ -14,7 +14,7 @@ import uvloop
|
|||||||
from fastapi import FastAPI, Request
|
from fastapi import FastAPI, Request
|
||||||
from fastapi.responses import JSONResponse, StreamingResponse
|
from fastapi.responses import JSONResponse, StreamingResponse
|
||||||
|
|
||||||
from vllm.disaggregated.pd_client import PDClient
|
from vllm.disaggregated.pd_contoller import PDController
|
||||||
from vllm.entrypoints.openai.protocol import (CompletionRequest,
|
from vllm.entrypoints.openai.protocol import (CompletionRequest,
|
||||||
CompletionResponse,
|
CompletionResponse,
|
||||||
ErrorResponse)
|
ErrorResponse)
|
||||||
@ -51,12 +51,12 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
|
|||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def pd_client_ctx(prefill_addr: str, decode_addr: str,
|
async def contoller_ctx(prefill_addr: str, decode_addr: str,
|
||||||
connector_addr: str,
|
contoller_addr: str,
|
||||||
model_name: str) -> AsyncIterator[PDClient]:
|
model_name: str) -> AsyncIterator[PDController]:
|
||||||
client = PDClient(prefill_addr, decode_addr, connector_addr, model_name)
|
c = PDController(prefill_addr, decode_addr, contoller_addr, model_name)
|
||||||
yield client
|
yield c
|
||||||
client.shutdown()
|
c.shutdown()
|
||||||
|
|
||||||
|
|
||||||
async def main(args, **uvicorn_kwargs):
|
async def main(args, **uvicorn_kwargs):
|
||||||
@ -69,12 +69,12 @@ async def main(args, **uvicorn_kwargs):
|
|||||||
# IPC Paths.
|
# IPC Paths.
|
||||||
prefill_addr = f"ipc://{args.prefill_addr}"
|
prefill_addr = f"ipc://{args.prefill_addr}"
|
||||||
decode_addr = f"ipc://{args.decode_addr}"
|
decode_addr = f"ipc://{args.decode_addr}"
|
||||||
connector_addr = f"ipc://{args.connector_addr}"
|
contoller_addr = f"ipc://{args.contoller_addr}"
|
||||||
|
|
||||||
# Start Engine.
|
# Start Engine.
|
||||||
async with pd_client_ctx(prefill_addr=prefill_addr,
|
async with contoller_ctx(prefill_addr=prefill_addr,
|
||||||
decode_addr=decode_addr,
|
decode_addr=decode_addr,
|
||||||
connector_addr=connector_addr,
|
contoller_addr=contoller_addr,
|
||||||
model_name=args.model) as engine_client:
|
model_name=args.model) as engine_client:
|
||||||
|
|
||||||
# Initialize App State.
|
# Initialize App State.
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user