mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-27 00:54:31 +08:00
added files
Signed-off-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com>
This commit is contained in:
parent
cf64b0e6a7
commit
2f29ae383a
323
vllm/disaggregated/pd_client.py
Normal file
323
vllm/disaggregated/pd_client.py
Normal file
@ -0,0 +1,323 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
from collections.abc import AsyncGenerator, Mapping
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import msgspec
|
||||||
|
import zmq
|
||||||
|
import zmq.asyncio
|
||||||
|
|
||||||
|
from vllm.config import DecodingConfig, ModelConfig
|
||||||
|
from vllm.core.scheduler import SchedulerOutputs
|
||||||
|
from vllm.entrypoints.disaggregated.types import PDRequest, PDResponse
|
||||||
|
from vllm.inputs.data import PromptType
|
||||||
|
from vllm.inputs.preprocess import InputPreprocessor
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.lora.request import LoRARequest
|
||||||
|
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||||
|
from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput
|
||||||
|
from vllm.pooling_params import PoolingParams
|
||||||
|
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||||
|
from vllm.sampling_params import BeamSearchParams, SamplingParams
|
||||||
|
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||||
|
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
|
||||||
|
from vllm.utils import Device
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
DEFAULT_MAX_TOKENS = 32000
|
||||||
|
|
||||||
|
|
||||||
|
class PDClient:
|
||||||
|
"""
|
||||||
|
PDEngine:
|
||||||
|
Equiavlent of AsyncLLM for P/D. Assumes there is
|
||||||
|
a Prefill and Decode service already running.
|
||||||
|
|
||||||
|
* TODO: actually handle errors and failure.
|
||||||
|
* TODO: support more than just text input.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, prefill_addr: str, decode_addr: str,
|
||||||
|
connector_addr: str, model_name: str):
|
||||||
|
# Request queues.
|
||||||
|
self.queues: dict[str, asyncio.Queue] = {}
|
||||||
|
|
||||||
|
# Serialization encoder.
|
||||||
|
self.encoder = msgspec.msgpack.Encoder()
|
||||||
|
|
||||||
|
# ZMQ communication.
|
||||||
|
self.ctx = zmq.asyncio.Context()
|
||||||
|
self.to_decode = self.ctx.socket(zmq.constants.PUSH)
|
||||||
|
self.to_decode.bind(f"{decode_addr}")
|
||||||
|
self.to_prefill = self.ctx.socket(zmq.constants.PUSH)
|
||||||
|
self.to_prefill.bind(f"{prefill_addr}")
|
||||||
|
self.connector_addr = connector_addr
|
||||||
|
self.decode_addr = decode_addr
|
||||||
|
self.prefill_addr = prefill_addr
|
||||||
|
|
||||||
|
# Background loops (started on first generate()).
|
||||||
|
self.output_handler: Optional[asyncio.Task] = None
|
||||||
|
self.log_running: Optional[asyncio.Task] = None
|
||||||
|
|
||||||
|
# Dummy: needed for EngineClient Protocol.
|
||||||
|
# TODO: refactor EngineClient to avoid needing this.
|
||||||
|
self.model_config = ModelConfig(model=model_name,
|
||||||
|
tokenizer=model_name,
|
||||||
|
tokenizer_mode="auto",
|
||||||
|
trust_remote_code=False,
|
||||||
|
dtype="auto",
|
||||||
|
task="generate",
|
||||||
|
seed=42)
|
||||||
|
|
||||||
|
# Dummy: needed for EngineClient Protocol.
|
||||||
|
# TODO: refactor EngineClient to avoid needing this.
|
||||||
|
init_kwargs = dict(
|
||||||
|
tokenizer_id=self.model_config.tokenizer,
|
||||||
|
enable_lora=False,
|
||||||
|
max_num_seqs=1024,
|
||||||
|
max_loras=0,
|
||||||
|
max_input_length=None,
|
||||||
|
tokenizer_mode=self.model_config.tokenizer_mode,
|
||||||
|
trust_remote_code=self.model_config.trust_remote_code,
|
||||||
|
revision=self.model_config.tokenizer_revision,
|
||||||
|
truncation_side=self.model_config.truncation_side)
|
||||||
|
self.tokenizer = TokenizerGroup(**init_kwargs)
|
||||||
|
|
||||||
|
def shutdown(self):
|
||||||
|
if (ctx := self.ctx) is not None:
|
||||||
|
ctx.destroy(linger=0)
|
||||||
|
if (task := self.log_running) is not None:
|
||||||
|
task.cancel()
|
||||||
|
if (task := self.output_handler) is not None:
|
||||||
|
task.cancel()
|
||||||
|
|
||||||
|
ipc_paths = [self.connector_addr, self.decode_addr, self.prefill_addr]
|
||||||
|
for path in ipc_paths:
|
||||||
|
socket_path = path.replace("ipc://", "")
|
||||||
|
if os.path.exists(socket_path):
|
||||||
|
os.remove(socket_path)
|
||||||
|
|
||||||
|
async def _run_log_running(self):
|
||||||
|
logger.info("Running requests: %d", len(self.queues))
|
||||||
|
await asyncio.sleep(10.)
|
||||||
|
|
||||||
|
async def _run_output_handler(self):
|
||||||
|
"""
|
||||||
|
Pull responses from Decode + Prefill engines and
|
||||||
|
distribute back to the generate() tasks.
|
||||||
|
"""
|
||||||
|
decoder = msgspec.msgpack.Decoder(PDResponse)
|
||||||
|
|
||||||
|
socket: Optional[zmq.asyncio.Socket] = None
|
||||||
|
try:
|
||||||
|
socket = self.ctx.socket(zmq.constants.PULL)
|
||||||
|
socket.bind(self.connector_addr)
|
||||||
|
|
||||||
|
while True:
|
||||||
|
reponse_bytes = await socket.recv()
|
||||||
|
response = decoder.decode(reponse_bytes)
|
||||||
|
logger.debug("Got Response: %s", response.request_id)
|
||||||
|
self.queues[response.request_id].put_nowait(response)
|
||||||
|
except:
|
||||||
|
# TODO: actually handle failure and shutdown.
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
if socket is not None:
|
||||||
|
socket.close(linger=0)
|
||||||
|
|
||||||
|
async def _prefill(
|
||||||
|
self,
|
||||||
|
request: PDRequest,
|
||||||
|
q: asyncio.Queue[PDResponse],
|
||||||
|
):
|
||||||
|
# Send request to the prefill instance.
|
||||||
|
req_bytes = self.encoder.encode(request)
|
||||||
|
await self.to_prefill.send(req_bytes, copy=False)
|
||||||
|
|
||||||
|
# Wait for the prefill to be done.
|
||||||
|
response = await q.get()
|
||||||
|
assert response.request_id == request.request_id
|
||||||
|
if not response.success:
|
||||||
|
# TODO: actual error handling and shutdown.
|
||||||
|
raise Exception("Failed Prefill Request.")
|
||||||
|
|
||||||
|
async def _decode(
|
||||||
|
self,
|
||||||
|
request: PDRequest,
|
||||||
|
q: asyncio.Queue[PDResponse],
|
||||||
|
) -> AsyncGenerator[PDRequest]:
|
||||||
|
|
||||||
|
# Send request to the decode instance.
|
||||||
|
req_bytes = self.encoder.encode(request)
|
||||||
|
await self.to_decode.send(req_bytes, copy=False)
|
||||||
|
|
||||||
|
# Iterate response queue and yield each response to caller..
|
||||||
|
finished = False
|
||||||
|
while not finished:
|
||||||
|
response = await q.get()
|
||||||
|
if not response.success:
|
||||||
|
# TODO: actual error handling and shutdown.
|
||||||
|
raise Exception("Failed Decode Request.")
|
||||||
|
finished = response.finish_reason is not None
|
||||||
|
yield response
|
||||||
|
|
||||||
|
def _to_request_output(
|
||||||
|
self,
|
||||||
|
pd_response: PDResponse,
|
||||||
|
prompt_token_ids: list[int],
|
||||||
|
) -> RequestOutput:
|
||||||
|
finished = pd_response.finish_reason is not None
|
||||||
|
return RequestOutput(
|
||||||
|
request_id=pd_response.request_id,
|
||||||
|
prompt=None,
|
||||||
|
prompt_token_ids=prompt_token_ids,
|
||||||
|
prompt_logprobs=None,
|
||||||
|
outputs=[
|
||||||
|
CompletionOutput(index=0,
|
||||||
|
text=pd_response.text,
|
||||||
|
token_ids=pd_response.token_ids,
|
||||||
|
cumulative_logprob=None,
|
||||||
|
logprobs=None,
|
||||||
|
finish_reason=pd_response.finish_reason,
|
||||||
|
stop_reason=pd_response.stop_reason)
|
||||||
|
],
|
||||||
|
finished=finished,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def generate(
|
||||||
|
self,
|
||||||
|
prompt: PromptType,
|
||||||
|
sampling_params: SamplingParams,
|
||||||
|
request_id: str,
|
||||||
|
lora_request: Optional[LoRARequest] = None,
|
||||||
|
trace_headers: Optional[Mapping[str, str]] = None,
|
||||||
|
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||||
|
priority: int = 0,
|
||||||
|
) -> AsyncGenerator[RequestOutput]:
|
||||||
|
# Start loops on first request.
|
||||||
|
if self.output_handler is None:
|
||||||
|
self.output_handler = asyncio.create_task(
|
||||||
|
self._run_output_handler())
|
||||||
|
self.log_running = asyncio.create_task(self._run_log_running())
|
||||||
|
|
||||||
|
# TODO: Expand to support the full matrix.
|
||||||
|
if "prompt_token_ids" not in prompt:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"We currently only support TokensPrompt for P/D!")
|
||||||
|
if lora_request is not None:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"We currently do not support LoRA for P/D!")
|
||||||
|
if trace_headers is not None:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"We currently do not support tracing for P/D!")
|
||||||
|
if prompt_adapter_request is not None:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"We currently do not support prompt adapter for P/D!")
|
||||||
|
if priority != 0:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"We currently do not support priority for P/D!")
|
||||||
|
if request_id in self.queues:
|
||||||
|
raise ValueError(f"Found duplicate request_id: {request_id}!")
|
||||||
|
|
||||||
|
# Queue to gather output from output_handler.
|
||||||
|
q: asyncio.Queue[PDResponse] = asyncio.Queue()
|
||||||
|
self.queues[request_id] = q
|
||||||
|
|
||||||
|
# (1) Perform the Prefill.
|
||||||
|
original_max_tokens = sampling_params.max_tokens
|
||||||
|
prompt_token_ids = prompt["prompt_token_ids"]
|
||||||
|
request = PDRequest(request_id=request_id,
|
||||||
|
prompt_token_ids=prompt_token_ids,
|
||||||
|
sampling_params=sampling_params)
|
||||||
|
request.sampling_params.max_tokens = 1
|
||||||
|
logger.debug("Sending Prefill: %s", request.request_id)
|
||||||
|
pd_response = await self._prefill(request, q)
|
||||||
|
|
||||||
|
# (2) Perform the Decodes.
|
||||||
|
logger.debug("Sending Decode: %s", request.request_id)
|
||||||
|
request.sampling_params.max_tokens = original_max_tokens
|
||||||
|
async for pd_response in self._decode(request, q):
|
||||||
|
logger.debug("Got Decode: %s", request.request_id)
|
||||||
|
yield self._to_request_output(pd_response, prompt_token_ids)
|
||||||
|
|
||||||
|
async def beam_search(
|
||||||
|
self,
|
||||||
|
prompt: PromptType,
|
||||||
|
request_id: str,
|
||||||
|
params: BeamSearchParams,
|
||||||
|
) -> AsyncGenerator[RequestOutput, None]:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def encode(
|
||||||
|
self,
|
||||||
|
prompt: PromptType,
|
||||||
|
pooling_params: PoolingParams,
|
||||||
|
request_id: str,
|
||||||
|
lora_request: Optional[LoRARequest] = None,
|
||||||
|
trace_headers: Optional[Mapping[str, str]] = None,
|
||||||
|
priority: int = 0,
|
||||||
|
) -> AsyncGenerator[PoolingRequestOutput, None]:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def abort(self, request_id: str) -> None:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def get_model_config(self) -> ModelConfig:
|
||||||
|
return self.model_config
|
||||||
|
|
||||||
|
async def get_decoding_config(self) -> DecodingConfig:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def get_input_preprocessor(self) -> InputPreprocessor:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def get_tokenizer(
|
||||||
|
self,
|
||||||
|
lora_request: Optional[LoRARequest] = None,
|
||||||
|
) -> AnyTokenizer:
|
||||||
|
if lora_request is not None:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"LoRA is not yet supported in the PDEngine.")
|
||||||
|
return self.tokenizer.get_lora_tokenizer(None)
|
||||||
|
|
||||||
|
async def is_tracing_enabled(self) -> bool:
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def do_log_stats(
|
||||||
|
self,
|
||||||
|
scheduler_outputs: Optional[SchedulerOutputs] = None,
|
||||||
|
model_output: Optional[list[SamplerOutput]] = None,
|
||||||
|
) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def check_health(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def start_profile(self) -> None:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def stop_profile(self) -> None:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def reset_prefix_cache(self,
|
||||||
|
device: Optional[Device] = None) -> None:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def sleep(self, level: int = 1) -> None:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def wake_up(self) -> None:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def is_sleeping(self) -> bool:
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def add_lora(self, lora_request: LoRARequest) -> None:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@property
|
||||||
|
def errored(self) -> bool:
|
||||||
|
return False
|
||||||
63
vllm/disaggregated/protocol.py
Normal file
63
vllm/disaggregated/protocol.py
Normal file
@ -0,0 +1,63 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import msgspec
|
||||||
|
|
||||||
|
from vllm import SamplingParams
|
||||||
|
from vllm.outputs import RequestOutput
|
||||||
|
|
||||||
|
# NOTE FOR DEVELOPERS:
|
||||||
|
# DO NOT USE PICKLE FOR THESE CLASSES. IN A MULTI NODE
|
||||||
|
# SETUP WE WILL USE TCP. WE CANNOT USE PICKLE OTHERWISE
|
||||||
|
# WE RISK REMOTE CODE EXECUTION FROM UNSTRUSTED USERS.
|
||||||
|
|
||||||
|
|
||||||
|
class PDRequestType:
|
||||||
|
GENERATION = b"generation"
|
||||||
|
ABORT = b"abort"
|
||||||
|
|
||||||
|
|
||||||
|
class PDGenerationRequest(msgspec.Struct):
|
||||||
|
request_id: str
|
||||||
|
prompt_token_ids: list[int]
|
||||||
|
sampling_params: SamplingParams
|
||||||
|
# TODO: support multimodal inputs.
|
||||||
|
|
||||||
|
|
||||||
|
class PDAbortRequest(msgspec.Struct):
|
||||||
|
request_id: str
|
||||||
|
|
||||||
|
|
||||||
|
class PDResponseType:
|
||||||
|
GENERATION = b"generation"
|
||||||
|
FAILURE = b"failure"
|
||||||
|
|
||||||
|
|
||||||
|
class PDGenerationResponse(msgspec.Struct):
|
||||||
|
request_id: str
|
||||||
|
text: str
|
||||||
|
token_ids: list[int]
|
||||||
|
finish_reason: Optional[str] = None
|
||||||
|
stop_reason: Optional[str] = None
|
||||||
|
# TODO: support full protocol.
|
||||||
|
logprobs = None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_request_output(
|
||||||
|
self, request_output: RequestOutput) -> "PDGenerationResponse":
|
||||||
|
assert len(request_output.outputs) == 1, "Only support N=1 right now."
|
||||||
|
out = request_output.outputs[0]
|
||||||
|
return PDGenerationResponse(
|
||||||
|
request_id=request_output.request_id,
|
||||||
|
text=out.text,
|
||||||
|
token_ids=out.token_ids,
|
||||||
|
finish_reason=out.finish_reason,
|
||||||
|
stop_reason=out.stop_reason,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PDGenerationFailure(msgspec.Struct):
|
||||||
|
request_id: str
|
||||||
|
error_message: str
|
||||||
|
engine_dead: bool
|
||||||
Loading…
x
Reference in New Issue
Block a user